enhance: add an unify vector index config checker (#36844)

issue: #34298

Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
pull/37180/head
foxspy 2024-10-28 10:11:37 +08:00 committed by GitHub
parent eeb67a3845
commit d7b2ffe5aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
93 changed files with 1216 additions and 785 deletions

View File

@ -33,6 +33,7 @@ func NewDiskANNIndex(metricType MetricType) Index {
return &diskANNIndex{
baseIndex: baseIndex{
metricType: metricType,
indexType: DISKANN,
},
}
}

View File

@ -33,6 +33,7 @@ func NewFlatIndex(metricType MetricType) Index {
return flatIndex{
baseIndex: baseIndex{
metricType: metricType,
indexType: Flat,
},
}
}
@ -54,6 +55,7 @@ func NewBinFlatIndex(metricType MetricType) Index {
return binFlatIndex{
baseIndex: baseIndex{
metricType: metricType,
indexType: BinFlat,
},
}
}

View File

@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
grpcproxy "github.com/milvus-io/milvus/internal/distributed/proxy"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@ -50,6 +51,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
}
func (n *Proxy) Prepare() error {
indexparamcheck.ValidateParamTable()
return n.svr.Prepare()
}

View File

@ -20,6 +20,79 @@
#include "index/IndexFactory.h"
#include "pb/index_cgo_msg.pb.h"
CStatus
ValidateIndexParams(const char* index_type,
enum CDataType data_type,
const uint8_t* serialized_index_params,
const uint64_t length) {
try {
auto index_params =
std::make_unique<milvus::proto::indexcgo::IndexParams>();
auto res =
index_params->ParseFromArray(serialized_index_params, length);
AssertInfo(res, "Unmarshall index params failed");
knowhere::Json json;
for (size_t i = 0; i < index_params->params_size(); i++) {
auto& param = index_params->params(i);
json[param.key()] = param.value();
}
milvus::DataType dataType(static_cast<milvus::DataType>(data_type));
knowhere::Status status;
std::string error_msg;
if (dataType == milvus::DataType::VECTOR_BINARY) {
status = knowhere::IndexStaticFaced<knowhere::bin1>::ConfigCheck(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
error_msg);
} else if (dataType == milvus::DataType::VECTOR_FLOAT) {
status = knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
error_msg);
} else if (dataType == milvus::DataType::VECTOR_BFLOAT16) {
status = knowhere::IndexStaticFaced<knowhere::bf16>::ConfigCheck(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
error_msg);
} else if (dataType == milvus::DataType::VECTOR_FLOAT16) {
status = knowhere::IndexStaticFaced<knowhere::fp16>::ConfigCheck(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
error_msg);
} else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) {
status = knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
error_msg);
} else {
status = knowhere::Status::invalid_args;
}
CStatus cStatus;
if (status == knowhere::Status::success) {
cStatus.error_code = milvus::Success;
cStatus.error_msg = "";
} else {
cStatus.error_code = milvus::ConfigInvalid;
cStatus.error_msg = strdup(error_msg.c_str());
}
return cStatus;
} catch (std::exception& e) {
auto cStatus = CStatus();
cStatus.error_code = milvus::UnexpectedError;
cStatus.error_msg = strdup(e.what());
return cStatus;
}
}
int
GetIndexListSize() {
return knowhere::IndexFactory::Instance().GetIndexFeatures().size();

View File

@ -17,6 +17,12 @@ extern "C" {
#include <stdbool.h>
#include "common/type_c.h"
CStatus
ValidateIndexParams(const char* index_type,
enum CDataType data_type,
const uint8_t* index_params,
const uint64_t length);
int
GetIndexListSize();

View File

@ -34,11 +34,11 @@ import (
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/workerpb"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"

View File

@ -28,11 +28,11 @@ import (
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/paramtable"

View File

@ -42,9 +42,9 @@ import (
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/workerpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/merr"
)
@ -620,13 +620,13 @@ func TestServer_AlterIndex(t *testing.T) {
s.stateCode.Store(commonpb.StateCode_Healthy)
t.Run("mmap_unsupported", func(t *testing.T) {
indexParams[0].Value = indexparamcheck.IndexRaftCagra
indexParams[0].Value = "GPU_CAGRA"
resp, err := s.AlterIndex(ctx, req)
assert.NoError(t, err)
assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
indexParams[0].Value = indexparamcheck.IndexFaissIvfFlat
indexParams[0].Value = "IVF_FLAT"
})
t.Run("param_value_invalied", func(t *testing.T) {

View File

@ -28,13 +28,13 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
@ -475,25 +475,18 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
if err := fillDimension(field, indexParams); err != nil {
return err
}
} else {
// used only for checker, should be deleted after checking
indexParams[IsSparseKey] = "true"
}
if err := checker.CheckValidDataType(field); err != nil {
if err := checker.CheckValidDataType(indexType, field); err != nil {
log.Info("create index with invalid data type", zap.Error(err), zap.String("data_type", field.GetDataType().String()))
return err
}
if err := checker.CheckTrain(indexParams); err != nil {
if err := checker.CheckTrain(field.DataType, indexParams); err != nil {
log.Info("create index with invalid parameters", zap.Error(err))
return err
}
if isSparse {
delete(indexParams, IsSparseKey)
}
return nil
}

View File

@ -35,9 +35,9 @@ import (
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"

View File

@ -40,6 +40,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
@ -48,7 +49,6 @@ import (
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/contextutil"
"github.com/milvus-io/milvus/pkg/util/crypto"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"

View File

@ -29,11 +29,11 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

View File

@ -24,8 +24,8 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

View File

@ -52,12 +52,12 @@ import (
"github.com/milvus-io/milvus/internal/querynodev2/segments/state"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/cgo"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metautil"

View File

@ -33,11 +33,11 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/initcore"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/contextutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"

View File

@ -29,12 +29,12 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querynodev2/segments/metricsutil"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/contextutil"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"

View File

@ -9,11 +9,11 @@ type AUTOINDEXChecker struct {
baseChecker
}
func (c *AUTOINDEXChecker) CheckTrain(params map[string]string) error {
func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
return nil
}
func (c *AUTOINDEXChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c *AUTOINDEXChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
return nil
}

View File

@ -19,29 +19,17 @@ package indexparamcheck
import (
"fmt"
"math"
"strings"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type baseChecker struct{}
func (c baseChecker) CheckTrain(params map[string]string) error {
// vector dimension should be checked on collection creation. this is just some basic check
isSparse := false
if val, exist := params[common.IsSparseKey]; exist {
val = strings.ToLower(val)
if val != "true" && val != "false" {
return fmt.Errorf("invalid is_sparse value: %s, must be true or false", val)
}
if val == "true" {
isSparse = true
}
}
if isSparse {
func (c baseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if typeutil.IsSparseFloatVectorType(dataType) {
if !CheckStrByValues(params, Metric, SparseMetrics) {
return fmt.Errorf("metric type not found or not supported for sparse float vectors, supported: %v", SparseMetrics)
}
@ -55,13 +43,13 @@ func (c baseChecker) CheckTrain(params map[string]string) error {
}
// CheckValidDataType check whether the field data type is supported for the index type
func (c baseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c baseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
return nil
}
func (c baseChecker) SetDefaultMetricTypeIfNotExist(m map[string]string, dType schemapb.DataType) {}
func (c baseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, m map[string]string) {}
func (c baseChecker) StaticCheck(params map[string]string) error {
func (c baseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
return errors.New("unsupported index type")
}

View File

@ -21,7 +21,7 @@ func Test_baseChecker_CheckTrain(t *testing.T) {
}
sparseParamsWithoutDim := map[string]string{
Metric: metric.IP,
common.IsSparseKey: "tRue",
common.IsSparseKey: "True",
}
sparseParamsWrongMetric := map[string]string{
Metric: metric.L2,
@ -42,9 +42,15 @@ func Test_baseChecker_CheckTrain(t *testing.T) {
{badSparseParams, false},
}
c := newBaseChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "HNSW"
var err error
if test.params[common.IsSparseKey] == "True" {
err = c.CheckTrain(schemapb.DataType_SparseFloatVector, test.params)
} else {
err = c.CheckTrain(schemapb.DataType_FloatVector, test.params)
}
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -115,7 +121,7 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) {
c := newBaseChecker()
for _, test := range cases {
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
err := c.CheckValidDataType(fieldSchema)
err := c.CheckValidDataType("FLAT", fieldSchema)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -126,5 +132,5 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) {
func Test_baseChecker_StaticCheck(t *testing.T) {
// TODO
assert.Error(t, newBaseChecker().StaticCheck(nil))
assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, nil))
}

View File

@ -0,0 +1,19 @@
package indexparamcheck
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
type binFlatChecker struct {
binaryVectorBaseChecker
}
func (c binFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
return c.binaryVectorBaseChecker.CheckTrain(dataType, params)
}
func (c binFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
return c.staticCheck(params)
}
func newBinFlatChecker() IndexChecker {
return &binFlatChecker{}
}

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -64,9 +65,10 @@ func Test_binFlatChecker_CheckTrain(t *testing.T) {
{p7, true},
}
c := newBinFlatChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "BINFLAT"
err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -134,10 +136,10 @@ func Test_binFlatChecker_CheckValidDataType(t *testing.T) {
},
}
c := newBinFlatChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
for _, test := range cases {
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
err := c.CheckValidDataType(fieldSchema)
err := c.CheckValidDataType("BINFLAT", fieldSchema)
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -2,13 +2,15 @@ package indexparamcheck
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
type binIVFFlatChecker struct {
binaryVectorBaseChecker
}
func (c binIVFFlatChecker) StaticCheck(params map[string]string) error {
func (c binIVFFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
if !CheckStrByValues(params, Metric, BinIvfMetrics) {
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinIvfMetrics)
}
@ -20,12 +22,12 @@ func (c binIVFFlatChecker) StaticCheck(params map[string]string) error {
return nil
}
func (c binIVFFlatChecker) CheckTrain(params map[string]string) error {
if err := c.binaryVectorBaseChecker.CheckTrain(params); err != nil {
func (c binIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.binaryVectorBaseChecker.CheckTrain(dataType, params); err != nil {
return err
}
return c.StaticCheck(params)
return c.StaticCheck(schemapb.DataType_BinaryVector, params)
}
func newBinIVFFlatChecker() IndexChecker {

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -115,9 +116,10 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
{p7, false},
}
c := newBinIVFFlatChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT")
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "BIN_IVF_FLAT"
err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -185,10 +187,10 @@ func Test_binIVFFlatChecker_CheckValidDataType(t *testing.T) {
},
}
c := newBinIVFFlatChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT")
for _, test := range cases {
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
err := c.CheckValidDataType(fieldSchema)
err := c.CheckValidDataType("BIN_IVF_FLAT", fieldSchema)
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -19,22 +19,22 @@ func (c binaryVectorBaseChecker) staticCheck(params map[string]string) error {
return nil
}
func (c binaryVectorBaseChecker) CheckTrain(params map[string]string) error {
if err := c.baseChecker.CheckTrain(params); err != nil {
func (c binaryVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.baseChecker.CheckTrain(dataType, params); err != nil {
return err
}
return c.staticCheck(params)
}
func (c binaryVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c binaryVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
if field.GetDataType() != schemapb.DataType_BinaryVector {
return fmt.Errorf("binary vector is only supported")
}
return nil
}
func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType)
}

View File

@ -67,10 +67,10 @@ func Test_binaryVectorBaseChecker_CheckValidDataType(t *testing.T) {
},
}
c := newBinaryVectorBaseChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
for _, test := range cases {
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
err := c.CheckValidDataType(fieldSchema)
err := c.CheckValidDataType("BINFLAT", fieldSchema)
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -0,0 +1,33 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_BitmapIndexChecker(t *testing.T) {
c := newBITMAPChecker()
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double, IsPrimaryKey: true}))
}

View File

@ -11,11 +11,11 @@ type BITMAPChecker struct {
scalarIndexChecker
}
func (c *BITMAPChecker) CheckTrain(params map[string]string) error {
return c.scalarIndexChecker.CheckTrain(params)
func (c *BITMAPChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
return c.scalarIndexChecker.CheckTrain(dataType, params)
}
func (c *BITMAPChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c *BITMAPChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
if field.IsPrimaryKey {
return fmt.Errorf("create bitmap index on primary key not supported")
}

View File

@ -3,6 +3,8 @@ package indexparamcheck
import (
"fmt"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
// diskannChecker checks if an diskann index can be built.
@ -10,8 +12,8 @@ type cagraChecker struct {
floatVectorBaseChecker
}
func (c *cagraChecker) CheckTrain(params map[string]string) error {
err := c.baseChecker.CheckTrain(params)
func (c *cagraChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
err := c.baseChecker.CheckTrain(dataType, params)
if err != nil {
return err
}
@ -54,7 +56,7 @@ func (c *cagraChecker) CheckTrain(params map[string]string) error {
return nil
}
func (c cagraChecker) StaticCheck(params map[string]string) error {
func (c cagraChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
return c.staticCheck(params)
}

View File

@ -6,6 +6,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -101,9 +103,13 @@ func Test_cagraChecker_CheckTrain(t *testing.T) {
{p14, false},
}
c := newCagraChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_CAGRA")
if c == nil {
log.Error("can not get index checker instance, please enable GPU and rerun it")
return
}
for _, test := range cases {
err := c.CheckTrain(test.params)
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -20,6 +20,8 @@ import (
"sync"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
)
type IndexCheckerMgr interface {
@ -34,36 +36,19 @@ type indexCheckerMgrImpl struct {
func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, error) {
mgr.once.Do(mgr.registerIndexChecker)
// Unify the vector index checker
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) {
return mgr.checkers[IndexVector], nil
}
adapter, ok := mgr.checkers[indexType]
if ok {
return adapter, nil
}
return nil, errors.New("Can not find conf adapter: " + indexType)
return nil, errors.New("Can not find index: " + indexType + " , please check")
}
func (mgr *indexCheckerMgrImpl) registerIndexChecker() {
mgr.checkers[IndexRaftIvfFlat] = newRaftIVFFlatChecker()
mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker()
mgr.checkers[IndexRaftCagra] = newCagraChecker()
mgr.checkers[IndexRaftBruteForce] = newRaftBruteForceChecker()
mgr.checkers[IndexFaissIDMap] = newFlatChecker()
mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker()
mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker()
mgr.checkers[IndexScaNN] = newScaNNChecker()
mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker()
mgr.checkers[IndexFaissBinIDMap] = newBinFlatChecker()
mgr.checkers[IndexFaissBinIvfFlat] = newBinIVFFlatChecker()
mgr.checkers[IndexHNSW] = newHnswChecker()
mgr.checkers[IndexDISKANN] = newDiskannChecker()
mgr.checkers[IndexSparseInverted] = newSparseInvertedIndexChecker()
mgr.checkers[IndexFaissHNSW] = newFloatVectorBaseChecker()
mgr.checkers[IndexFaissHNSWPQ] = newFloatVectorBaseChecker()
mgr.checkers[IndexFaissHNSWSQ] = newFloatVectorBaseChecker()
mgr.checkers[IndexFaissHNSWPRQ] = newFloatVectorBaseChecker()
// WAND doesn't have more index params than sparse inverted index, thus
// using the same checker.
mgr.checkers[IndexSparseWand] = newSparseInvertedIndexChecker()
mgr.checkers[IndexVector] = newVecIndexChecker()
mgr.checkers[IndexINVERTED] = newINVERTEDChecker()
mgr.checkers[IndexSTLSORT] = newSTLSORTChecker()
mgr.checkers["Asceneding"] = newSTLSORTChecker()

View File

@ -29,52 +29,52 @@ func Test_GetConfAdapterMgrInstance(t *testing.T) {
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, adapter)
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
adapter, err = adapterMgr.GetChecker("FLAT")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*flatChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
adapter, err = adapterMgr.GetChecker("IVF_FLAT")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*ivfBaseChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexScaNN)
adapter, err = adapterMgr.GetChecker("SCANN")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*scaNNChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ)
adapter, err = adapterMgr.GetChecker("IVF_PQ")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*ivfPQChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8)
adapter, err = adapterMgr.GetChecker("IVF_SQ8")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*ivfSQChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap)
adapter, err = adapterMgr.GetChecker("BIN_FLAT")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*binFlatChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat)
adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*binIVFFlatChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexHNSW)
adapter, err = adapterMgr.GetChecker("HNSW")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*hnswChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
}
@ -89,52 +89,52 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) {
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, adapter)
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
adapter, err = adapterMgr.GetChecker("FLAT")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*flatChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
adapter, err = adapterMgr.GetChecker("IVF_FLAT")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*ivfBaseChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexScaNN)
adapter, err = adapterMgr.GetChecker("SCANN")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*scaNNChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ)
adapter, err = adapterMgr.GetChecker("IVF_PQ")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*ivfPQChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8)
adapter, err = adapterMgr.GetChecker("IVF_SQ8")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*ivfSQChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap)
adapter, err = adapterMgr.GetChecker("BIN_FLAT")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*binFlatChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat)
adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*binIVFFlatChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetChecker(IndexHNSW)
adapter, err = adapterMgr.GetChecker("HNSW")
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*hnswChecker)
_, ok = adapter.(*vecIndexChecker)
assert.Equal(t, true, ok)
}
@ -146,7 +146,7 @@ func TestConfAdapterMgrImpl_GetAdapter_multiple_threads(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
adapter, err := mgr.GetChecker(IndexHNSW)
adapter, err := mgr.GetChecker("HNSW")
assert.NoError(t, err)
assert.NotNil(t, adapter)
}()

View File

@ -65,7 +65,7 @@ var (
CagraBuildAlgoTypes = []string{CargaBuildAlgoIVFPQ, CargaBuildAlgoNNDESCENT}
supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
SparseMetrics = []string{metric.IP} // const
SparseMetrics = []string{metric.IP, metric.BM25} // const
)
const (

View File

@ -1,11 +1,13 @@
package indexparamcheck
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
// diskannChecker checks if an diskann index can be built.
type diskannChecker struct {
floatVectorBaseChecker
}
func (c diskannChecker) StaticCheck(params map[string]string) error {
func (c diskannChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
return c.staticCheck(params)
}

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -72,9 +73,10 @@ func Test_diskannChecker_CheckTrain(t *testing.T) {
{p7, false},
}
c := newDiskannChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN")
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "DISKANN"
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -142,9 +144,9 @@ func Test_diskannChecker_CheckValidDataType(t *testing.T) {
},
}
c := newDiskannChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN")
for _, test := range cases {
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
err := c.CheckValidDataType("DISKANN", &schemapb.FieldSchema{DataType: test.dType})
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -1,10 +1,12 @@
package indexparamcheck
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
type flatChecker struct {
floatVectorBaseChecker
}
func (c flatChecker) StaticCheck(m map[string]string) error {
func (c flatChecker) StaticCheck(dataType schemapb.DataType, m map[string]string) error {
return c.staticCheck(m)
}

View File

@ -6,6 +6,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -52,9 +54,10 @@ func Test_flatChecker_CheckTrain(t *testing.T) {
{p7, false},
}
c := newFlatChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT")
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "FLAT"
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -89,9 +92,10 @@ func Test_flatChecker_StaticCheck(t *testing.T) {
},
}
c := newFlatChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT")
for _, test := range cases {
err := c.StaticCheck(test.params)
test.params[common.IndexTypeKey] = "FLAT"
err := c.StaticCheck(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -20,22 +20,22 @@ func (c floatVectorBaseChecker) staticCheck(params map[string]string) error {
return nil
}
func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error {
if err := c.baseChecker.CheckTrain(params); err != nil {
func (c floatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.baseChecker.CheckTrain(dataType, params); err != nil {
return err
}
return c.staticCheck(params)
}
func (c floatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c floatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
if !typeutil.IsDenseFloatVectorType(field.GetDataType()) {
return fmt.Errorf("data type should be FloatVector, Float16Vector or BFloat16Vector")
}
return nil
}
func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
}

View File

@ -63,13 +63,13 @@ func Test_floatVectorBaseChecker_CheckValidDataType(t *testing.T) {
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: false,
errIsNil: true,
},
}
c := newFloatVectorBaseChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
for _, test := range cases {
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType})
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -12,7 +12,7 @@ type hnswChecker struct {
baseChecker
}
func (c hnswChecker) StaticCheck(params map[string]string) error {
func (c hnswChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
return errOutOfRange(EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction)
}
@ -25,21 +25,21 @@ func (c hnswChecker) StaticCheck(params map[string]string) error {
return nil
}
func (c hnswChecker) CheckTrain(params map[string]string) error {
if err := c.StaticCheck(params); err != nil {
func (c hnswChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.StaticCheck(dataType, params); err != nil {
return err
}
return c.baseChecker.CheckTrain(params)
return c.baseChecker.CheckTrain(dataType, params)
}
func (c hnswChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c hnswChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
if !typeutil.IsVectorType(field.GetDataType()) {
return fmt.Errorf("can't build hnsw in not vector type")
}
return nil
}
func (c hnswChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
func (c hnswChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
if typeutil.IsDenseFloatVectorType(dType) {
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
} else if typeutil.IsSparseFloatVectorType(dType) {

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -88,13 +89,19 @@ func Test_hnswChecker_CheckTrain(t *testing.T) {
{p3, true},
{p4, true},
{p5, true},
{p6, false},
{p7, false},
{p6, true},
{p7, true},
}
c := newHnswChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "HNSW"
var err error
if CheckStrByValues(test.params, common.MetricTypeKey, BinaryVectorMetrics) {
err = c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
} else {
err = c.CheckTrain(schemapb.DataType_FloatVector, test.params)
}
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -162,9 +169,9 @@ func Test_hnswChecker_CheckValidDataType(t *testing.T) {
},
}
c := newHnswChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
for _, test := range cases {
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType})
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -200,14 +207,14 @@ func Test_hnswChecker_SetDefaultMetricType(t *testing.T) {
},
}
c := newHnswChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
for _, test := range cases {
p := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
}
c.SetDefaultMetricTypeIfNotExist(p, test.dType)
c.SetDefaultMetricTypeIfNotExist(test.dType, p)
assert.Equal(t, p[Metric], test.metricType)
}
}

View File

@ -0,0 +1,37 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_HybridIndexChecker(t *testing.T) {
c := newHYBRIDChecker()
assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, map[string]string{"bitmap_cardinality_limit": "100"}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{}))
assert.Error(t, c.CheckTrain(schemapb.DataType_Float, map[string]string{"bitmap_cardinality_limit": "0"}))
assert.Error(t, c.CheckTrain(schemapb.DataType_Double, map[string]string{"bitmap_cardinality_limit": "2000"}))
}

View File

@ -12,15 +12,15 @@ type HYBRIDChecker struct {
scalarIndexChecker
}
func (c *HYBRIDChecker) CheckTrain(params map[string]string) error {
func (c *HYBRIDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if !CheckIntByRange(params, common.BitmapCardinalityLimitKey, 1, MaxBitmapCardinalityLimit) {
return fmt.Errorf("failed to check bitmap cardinality limit, should be larger than 0 and smaller than %d",
MaxBitmapCardinalityLimit)
}
return c.scalarIndexChecker.CheckTrain(params)
return c.scalarIndexChecker.CheckTrain(dataType, params)
}
func (c *HYBRIDChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c *HYBRIDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
mainType := field.GetDataType()
elemType := field.GetElementType()
if !typeutil.IsBoolType(mainType) && !typeutil.IsIntegerType(mainType) &&

View File

@ -21,8 +21,8 @@ import (
)
type IndexChecker interface {
CheckTrain(map[string]string) error
CheckValidDataType(field *schemapb.FieldSchema) error
SetDefaultMetricTypeIfNotExist(map[string]string, schemapb.DataType)
StaticCheck(map[string]string) error
CheckTrain(schemapb.DataType, map[string]string) error
CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error
SetDefaultMetricTypeIfNotExist(schemapb.DataType, map[string]string)
StaticCheck(schemapb.DataType, map[string]string) error
}

View File

@ -15,6 +15,7 @@ import (
"fmt"
"strconv"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common"
)
@ -23,31 +24,7 @@ type IndexType = string
// IndexType definitions
const (
// vector index
IndexGpuBF IndexType = "GPU_BRUTE_FORCE"
IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT"
IndexRaftIvfPQ IndexType = "GPU_IVF_PQ"
IndexRaftCagra IndexType = "GPU_CAGRA"
IndexRaftBruteForce IndexType = "GPU_BRUTE_FORCE"
IndexFaissIDMap IndexType = "FLAT" // no index is built.
IndexFaissIvfFlat IndexType = "IVF_FLAT"
IndexFaissIvfPQ IndexType = "IVF_PQ"
IndexScaNN IndexType = "SCANN"
IndexFaissIvfSQ8 IndexType = "IVF_SQ8"
IndexFaissBinIDMap IndexType = "BIN_FLAT"
IndexFaissBinIvfFlat IndexType = "BIN_IVF_FLAT"
IndexHNSW IndexType = "HNSW"
IndexDISKANN IndexType = "DISKANN"
IndexSparseInverted IndexType = "SPARSE_INVERTED_INDEX"
IndexSparseWand IndexType = "SPARSE_WAND"
// For temporary use, will be removed in the future.
// 1. All Index related param check will be moved to Knowhere recently.
// 2. FAISS_HNSW_xxx will be rename to HNSW_xxx after QA test. We keep the original name for comparison purpose.
// TODO: @liliu-z @foxspy
IndexFaissHNSW IndexType = "FAISS_HNSW_FLAT"
IndexFaissHNSWPQ IndexType = "FAISS_HNSW_PQ"
IndexFaissHNSWSQ IndexType = "FAISS_HNSW_SQ"
IndexFaissHNSWPRQ IndexType = "FAISS_HNSW_PRQ"
IndexVector IndexType = "VECINDEX"
// scalar index
IndexSTLSORT IndexType = "STL_SORT"
@ -66,28 +43,12 @@ func IsScalarIndexType(indexType IndexType) bool {
}
func IsGpuIndex(indexType IndexType) bool {
return indexType == IndexGpuBF ||
indexType == IndexRaftIvfFlat ||
indexType == IndexRaftIvfPQ ||
indexType == IndexRaftCagra
return vecindexmgr.GetVecIndexMgrInstance().IsGPUVecIndex(indexType)
}
// IsVectorMmapIndex check if the vector index can be mmaped
func IsVectorMmapIndex(indexType IndexType) bool {
return indexType == IndexFaissIDMap ||
indexType == IndexFaissIvfFlat ||
indexType == IndexFaissIvfPQ ||
indexType == IndexFaissIvfSQ8 ||
indexType == IndexFaissBinIDMap ||
indexType == IndexFaissBinIvfFlat ||
indexType == IndexHNSW ||
indexType == IndexFaissHNSW ||
indexType == IndexFaissHNSWPQ ||
indexType == IndexFaissHNSWSQ ||
indexType == IndexFaissHNSWPRQ ||
indexType == IndexScaNN ||
indexType == IndexSparseInverted ||
indexType == IndexSparseWand
return vecindexmgr.GetVecIndexMgrInstance().IsMMapSupported(indexType)
}
func IsOffsetCacheSupported(indexType IndexType) bool {
@ -95,7 +56,7 @@ func IsOffsetCacheSupported(indexType IndexType) bool {
}
func IsDiskIndex(indexType IndexType) bool {
return indexType == IndexDISKANN
return vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType)
}
func IsScalarMmapIndex(indexType IndexType) bool {

View File

@ -12,11 +12,11 @@ type INVERTEDChecker struct {
scalarIndexChecker
}
func (c *INVERTEDChecker) CheckTrain(params map[string]string) error {
return c.scalarIndexChecker.CheckTrain(params)
func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
return c.scalarIndexChecker.CheckTrain(dataType, params)
}
func (c *INVERTEDChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c *INVERTEDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
dType := field.GetDataType()
if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) &&
!typeutil.IsArrayType(dType) {

View File

@ -0,0 +1,25 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_INVERTEDIndexChecker(t *testing.T) {
c := newINVERTEDChecker()
assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, map[string]string{}))
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Array}))
assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}))
}

View File

@ -0,0 +1,28 @@
package indexparamcheck
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
type ivfBaseChecker struct {
floatVectorBaseChecker
}
func (c ivfBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
return errOutOfRange(NLIST, MinNList, MaxNList)
}
// skip check number of rows
return c.floatVectorBaseChecker.staticCheck(params)
}
func (c ivfBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.StaticCheck(dataType, params); err != nil {
return err
}
return c.floatVectorBaseChecker.CheckTrain(dataType, params)
}
func newIVFBaseChecker() IndexChecker {
return &ivfBaseChecker{}
}

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -70,9 +71,10 @@ func Test_ivfBaseChecker_CheckTrain(t *testing.T) {
{p7, false},
}
c := newIVFBaseChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT")
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "IVF_FLAT"
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -140,9 +142,9 @@ func Test_ivfBaseChecker_CheckValidDataType(t *testing.T) {
},
}
c := newIVFBaseChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT")
for _, test := range cases {
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
err := c.CheckValidDataType("IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType})
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -3,6 +3,8 @@ package indexparamcheck
import (
"fmt"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
// ivfPQChecker checks if a IVF_PQ index can be built.
@ -11,8 +13,8 @@ type ivfPQChecker struct {
}
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
func (c *ivfPQChecker) CheckTrain(params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
func (c *ivfPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
return err
}

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -141,9 +142,11 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
{p7, false},
}
// c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ")
c := newIVFPQChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "IVF_PQ"
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -211,9 +214,9 @@ func Test_ivfPQChecker_CheckValidDataType(t *testing.T) {
},
}
c := newIVFPQChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ")
for _, test := range cases {
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
err := c.CheckValidDataType("IVF_PQ", &schemapb.FieldSchema{DataType: test.dType})
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -2,6 +2,8 @@ package indexparamcheck
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
// ivfSQChecker checks if a IVF_SQ index can be built.
@ -22,11 +24,11 @@ func (c *ivfSQChecker) checkNBits(params map[string]string) error {
}
// CheckTrain returns true if the index can be built with the specific index parameters.
func (c *ivfSQChecker) CheckTrain(params map[string]string) error {
func (c *ivfSQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.checkNBits(params); err != nil {
return err
}
return c.ivfBaseChecker.CheckTrain(params)
return c.ivfBaseChecker.CheckTrain(dataType, params)
}
func newIVFSQChecker() IndexChecker {

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -78,7 +79,6 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) {
}{
{validParams, true},
{validParamsWithNBits, true},
{paramsWithInvalidNBits, false},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{p1, true},
@ -90,9 +90,10 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) {
{p7, false},
}
c := newIVFSQChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ")
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "IVF_SQ"
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -160,9 +161,9 @@ func Test_ivfSQChecker_CheckValidDataType(t *testing.T) {
},
}
c := newIVFSQChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ")
for _, test := range cases {
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
err := c.CheckValidDataType("IVF_SQ8", &schemapb.FieldSchema{DataType: test.dType})
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -1,14 +1,18 @@
package indexparamcheck
import "fmt"
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
type raftBruteForceChecker struct {
floatVectorBaseChecker
}
// raftBrustForceChecker checks if a Brute_Force index can be built.
func (c raftBruteForceChecker) CheckTrain(params map[string]string) error {
if err := c.floatVectorBaseChecker.CheckTrain(params); err != nil {
func (c raftBruteForceChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.floatVectorBaseChecker.CheckTrain(dataType, params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, RaftMetrics) {

View File

@ -6,6 +6,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -52,9 +55,14 @@ func Test_raftbfChecker_CheckTrain(t *testing.T) {
{p7, false},
}
c := newRaftBruteForceChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_BRUTE_FORCE")
if c == nil {
log.Error("can not get index checker instance, please enable GPU and rerun it")
return
}
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "GPU_BRUTE_FORCE"
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -1,6 +1,10 @@
package indexparamcheck
import "fmt"
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
// raftIVFChecker checks if a RAFT_IVF_Flat index can be built.
type raftIVFFlatChecker struct {
@ -8,8 +12,8 @@ type raftIVFFlatChecker struct {
}
// CheckTrain checks if ivf-flat index can be built with the specific index parameters.
func (c *raftIVFFlatChecker) CheckTrain(params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
func (c *raftIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, RaftMetrics) {

View File

@ -7,6 +7,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -84,9 +86,14 @@ func Test_raftIvfFlatChecker_CheckTrain(t *testing.T) {
{p9, false},
}
c := newRaftIVFFlatChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_FLAT")
if c == nil {
log.Error("can not get index checker instance, please enable GPU and rerun it")
return
}
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "GPU_IVF_FLAT"
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -154,9 +161,13 @@ func Test_raftIvfFlatChecker_CheckValidDataType(t *testing.T) {
},
}
c := newRaftIVFFlatChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_FLAT")
if c == nil {
log.Error("can not get index checker instance, please enable GPU and rerun it")
return
}
for _, test := range cases {
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
err := c.CheckValidDataType("GPU_IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType})
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -3,6 +3,8 @@ package indexparamcheck
import (
"fmt"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
// raftIVFPQChecker checks if a RAFT_IVF_PQ index can be built.
@ -11,8 +13,8 @@ type raftIVFPQChecker struct {
}
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
func (c *raftIVFPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, RaftMetrics) {

View File

@ -7,6 +7,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -144,9 +146,14 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
{p9, false},
}
c := newRaftIVFPQChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_PQ")
if c == nil {
log.Error("can not get index checker instance, please enable GPU and rerun it")
return
}
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "GPU_IVF_PQ"
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -214,9 +221,13 @@ func Test_raftIVFPQChecker_CheckValidDataType(t *testing.T) {
},
}
c := newRaftIVFPQChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_PQ")
if c == nil {
log.Error("can not get index checker instance, please enable GPU and rerun it")
return
}
for _, test := range cases {
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
err := c.CheckValidDataType("GPU_IVF_PQ", &schemapb.FieldSchema{DataType: test.dType})
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -0,0 +1,11 @@
package indexparamcheck
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
type scalarIndexChecker struct {
baseChecker
}
func (c scalarIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
return nil
}

View File

@ -4,9 +4,11 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func TestCheckIndexValid(t *testing.T) {
scalarIndexChecker := &scalarIndexChecker{}
assert.NoError(t, scalarIndexChecker.CheckTrain(map[string]string{}))
assert.NoError(t, scalarIndexChecker.CheckTrain(schemapb.DataType_Bool, map[string]string{}))
}

View File

@ -3,6 +3,8 @@ package indexparamcheck
import (
"fmt"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
// scaNNChecker checks if a SCANN index can be built.
@ -11,8 +13,8 @@ type scaNNChecker struct {
}
// CheckTrain checks if SCANN index can be built with the specific index parameters.
func (c *scaNNChecker) CheckTrain(params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
func (c *scaNNChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
return err
}

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
)
@ -87,9 +88,10 @@ func Test_scaNNChecker_CheckTrain(t *testing.T) {
{p7, false},
}
c := newScaNNChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("SCANN")
for _, test := range cases {
err := c.CheckTrain(test.params)
test.params[common.IndexTypeKey] = "SCANN"
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
@ -159,7 +161,7 @@ func Test_scaNNChecker_CheckValidDataType(t *testing.T) {
c := newScaNNChecker()
for _, test := range cases {
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
err := c.CheckValidDataType("SCANN", &schemapb.FieldSchema{DataType: test.dType})
if test.errIsNil {
assert.NoError(t, err)
} else {

View File

@ -12,7 +12,7 @@ import (
// sparse vector don't check for dim, but baseChecker does, thus not including baseChecker
type sparseFloatVectorBaseChecker struct{}
func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) error {
func (c sparseFloatVectorBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
if !CheckStrByValues(params, Metric, SparseMetrics) {
return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics)
}
@ -20,7 +20,7 @@ func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) erro
return nil
}
func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error {
func (c sparseFloatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
dropRatioBuildStr, exist := params[SparseDropRatioBuild]
if exist {
dropRatioBuild, err := strconv.ParseFloat(dropRatioBuildStr, 64)
@ -48,14 +48,14 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error
return nil
}
func (c sparseFloatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c sparseFloatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
if !typeutil.IsSparseFloatVectorType(field.GetDataType()) {
return fmt.Errorf("only sparse float vector is supported for the specified index tpye")
}
return nil
}
func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType)
}

View File

@ -6,84 +6,95 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
)
func Test_sparseFloatVectorBaseChecker_StaticCheck(t *testing.T) {
validParams := map[string]string{
Metric: "IP",
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
Metric: "IP",
}
invalidParams := map[string]string{
Metric: "L2",
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
Metric: "L2",
}
c := newSparseFloatVectorBaseChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX")
t.Run("valid metric", func(t *testing.T) {
err := c.StaticCheck(validParams)
err := c.StaticCheck(schemapb.DataType_SparseFloatVector, validParams)
assert.NoError(t, err)
})
t.Run("invalid metric", func(t *testing.T) {
err := c.StaticCheck(invalidParams)
err := c.StaticCheck(schemapb.DataType_SparseFloatVector, invalidParams)
assert.Error(t, err)
})
}
func Test_sparseFloatVectorBaseChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
Metric: "IP",
SparseDropRatioBuild: "0.5",
BM25K1: "1.5",
BM25B: "0.5",
}
invalidDropRatio := map[string]string{
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
Metric: "IP",
SparseDropRatioBuild: "1.5",
}
invalidBM25K1 := map[string]string{
BM25K1: "3.5",
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
Metric: "IP",
BM25K1: "3.5",
}
invalidBM25B := map[string]string{
BM25B: "1.5",
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
Metric: "IP",
BM25B: "1.5",
}
c := newSparseFloatVectorBaseChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX")
t.Run("valid params", func(t *testing.T) {
err := c.CheckTrain(validParams)
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, validParams)
assert.NoError(t, err)
})
t.Run("invalid drop ratio", func(t *testing.T) {
err := c.CheckTrain(invalidDropRatio)
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidDropRatio)
assert.Error(t, err)
})
t.Run("invalid BM25K1", func(t *testing.T) {
err := c.CheckTrain(invalidBM25K1)
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidBM25K1)
assert.Error(t, err)
})
t.Run("invalid BM25B", func(t *testing.T) {
err := c.CheckTrain(invalidBM25B)
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidBM25B)
assert.Error(t, err)
})
}
func Test_sparseFloatVectorBaseChecker_CheckValidDataType(t *testing.T) {
c := newSparseFloatVectorBaseChecker()
c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX")
t.Run("valid data type", func(t *testing.T) {
field := &schemapb.FieldSchema{DataType: schemapb.DataType_SparseFloatVector}
err := c.CheckValidDataType(field)
err := c.CheckValidDataType("SPARSE_WAND", field)
assert.NoError(t, err)
})
t.Run("invalid data type", func(t *testing.T) {
field := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
err := c.CheckValidDataType(field)
err := c.CheckValidDataType("SPARSE_WAND", field)
assert.Error(t, err)
})
}

View File

@ -12,11 +12,11 @@ type STLSORTChecker struct {
scalarIndexChecker
}
func (c *STLSORTChecker) CheckTrain(params map[string]string) error {
return c.scalarIndexChecker.CheckTrain(params)
func (c *STLSORTChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
return c.scalarIndexChecker.CheckTrain(dataType, params)
}
func (c *STLSORTChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c *STLSORTChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
if !typeutil.IsArithmetic(field.GetDataType()) {
return fmt.Errorf("STL_SORT are only supported on numeric field")
}

View File

@ -0,0 +1,22 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_STLSORTIndexChecker(t *testing.T) {
c := newSTLSORTChecker()
assert.NoError(t, c.CheckTrain(schemapb.DataType_Int64, map[string]string{}))
assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
}

View File

@ -12,11 +12,11 @@ type TRIEChecker struct {
scalarIndexChecker
}
func (c *TRIEChecker) CheckTrain(params map[string]string) error {
return c.scalarIndexChecker.CheckTrain(params)
func (c *TRIEChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
return c.scalarIndexChecker.CheckTrain(dataType, params)
}
func (c *TRIEChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
func (c *TRIEChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
if !typeutil.IsStringType(field.GetDataType()) {
return fmt.Errorf("TRIE are only supported on varchar field")
}

View File

@ -0,0 +1,23 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_TrieIndexChecker(t *testing.T) {
c := newTRIEChecker()
assert.NoError(t, c.CheckTrain(schemapb.DataType_VarChar, map[string]string{}))
assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
}

View File

@ -20,7 +20,10 @@ import (
"fmt"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
// CheckIntByRange check if the data corresponding to the key is in the range of [min, max].
@ -69,3 +72,30 @@ func setDefaultIfNotExist(params map[string]string, key string, defaultValue str
params[key] = defaultValue
}
}
func CheckAutoIndexHelper(key string, m map[string]string, dtype schemapb.DataType) {
indexType, ok := m[common.IndexTypeKey]
if !ok {
panic(fmt.Sprintf("%s invalid, index type not found", key))
}
checker, err := GetIndexCheckerMgrInstance().GetChecker(indexType)
if err != nil {
panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType))
}
if err := checker.StaticCheck(dtype, m); err != nil {
panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error()))
}
}
func CheckAutoIndexConfig() {
autoIndexCfg := &paramtable.Get().AutoIndexConfig
CheckAutoIndexHelper(autoIndexCfg.IndexParams.Key, autoIndexCfg.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
CheckAutoIndexHelper(autoIndexCfg.BinaryIndexParams.Key, autoIndexCfg.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
CheckAutoIndexHelper(autoIndexCfg.SparseIndexParams.Key, autoIndexCfg.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector)
}
func ValidateParamTable() {
CheckAutoIndexConfig()
}

View File

@ -0,0 +1,269 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func Test_CheckIntByRange(t *testing.T) {
params := map[string]string{
"1": strconv.Itoa(1),
"2": strconv.Itoa(2),
"3": strconv.Itoa(3),
"s1": "s1",
"s2": "s2",
"s3": "s3",
}
cases := []struct {
params map[string]string
key string
min int
max int
want bool
}{
{params, "1", 0, 4, true},
{params, "2", 0, 4, true},
{params, "3", 0, 4, true},
{params, "1", 4, 5, false},
{params, "2", 4, 5, false},
{params, "3", 4, 5, false},
{params, "4", 0, 4, false},
{params, "5", 0, 4, false},
{params, "6", 0, 4, false},
{params, "s1", 0, 4, false},
{params, "s2", 0, 4, false},
{params, "s3", 0, 4, false},
{params, "s4", 0, 4, false},
{params, "s5", 0, 4, false},
{params, "s6", 0, 4, false},
}
for _, test := range cases {
if got := CheckIntByRange(test.params, test.key, test.min, test.max); got != test.want {
t.Errorf("CheckIntByRange(%v, %v, %v, %v) = %v", test.params, test.key, test.min, test.max, test.want)
}
}
}
func Test_CheckStrByValues(t *testing.T) {
params := map[string]string{
"1": strconv.Itoa(1),
"2": strconv.Itoa(2),
"3": strconv.Itoa(3),
}
cases := []struct {
params map[string]string
key string
container []string
want bool
}{
{params, "1", []string{"1", "2", "3"}, true},
{params, "2", []string{"1", "2", "3"}, true},
{params, "3", []string{"1", "2", "3"}, true},
{params, "1", []string{"4", "5", "6"}, false},
{params, "2", []string{"4", "5", "6"}, false},
{params, "3", []string{"4", "5", "6"}, false},
{params, "1", []string{}, false},
{params, "2", []string{}, false},
{params, "3", []string{}, false},
{params, "4", []string{"1", "2", "3"}, false},
{params, "5", []string{"1", "2", "3"}, false},
{params, "6", []string{"1", "2", "3"}, false},
{params, "4", []string{"4", "5", "6"}, false},
{params, "5", []string{"4", "5", "6"}, false},
{params, "6", []string{"4", "5", "6"}, false},
{params, "4", []string{}, false},
{params, "5", []string{}, false},
{params, "6", []string{}, false},
}
for _, test := range cases {
if got := CheckStrByValues(test.params, test.key, test.container); got != test.want {
t.Errorf("CheckStrByValues(%v, %v, %v) = %v", test.params, test.key, test.container, test.want)
}
}
}
func Test_CheckAutoIndex(t *testing.T) {
t.Run("index type not found", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"M": 30}`)
p := &paramtable.AutoIndexConfig{
IndexParams: paramtable.ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.Panics(t, func() {
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
})
})
t.Run("unsupported index type", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"index_type": "not supported"}`)
p := &paramtable.AutoIndexConfig{
IndexParams: paramtable.ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.Panics(t, func() {
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
})
})
t.Run("normal case, hnsw", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"M": 30,"efConstruction": 360,"index_type": "HNSW"}`)
p := &paramtable.AutoIndexConfig{
IndexParams: paramtable.ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
assert.NotPanics(t, func() {
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, metric.COSINE, metricType)
})
t.Run("normal case, binary vector", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.binary.build", `{"nlist": 1024, "index_type": "BIN_IVF_FLAT"}`)
p := &paramtable.AutoIndexConfig{
BinaryIndexParams: paramtable.ParamItem{
Key: "autoIndex.params.binary.build",
},
}
p.BinaryIndexParams.Init(mgr)
p.SetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
assert.NotPanics(t, func() {
CheckAutoIndexHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
})
metricType, exist := p.BinaryIndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, metric.HAMMING, metricType)
})
t.Run("normal case, sparse vector", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.sparse.build", `{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}`)
p := &paramtable.AutoIndexConfig{
SparseIndexParams: paramtable.ParamItem{
Key: "autoIndex.params.sparse.build",
},
}
p.SparseIndexParams.Init(mgr)
p.SetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
assert.NotPanics(t, func() {
CheckAutoIndexHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector)
})
metricType, exist := p.SparseIndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, metric.IP, metricType)
})
t.Run("normal case, ivf flat", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
p := &paramtable.AutoIndexConfig{
IndexParams: paramtable.ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
assert.NotPanics(t, func() {
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, metric.COSINE, metricType)
})
t.Run("normal case, ivf flat", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
p := &paramtable.AutoIndexConfig{
IndexParams: paramtable.ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
assert.NotPanics(t, func() {
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, metric.COSINE, metricType)
})
t.Run("normal case, diskann", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"index_type": "DISKANN"}`)
p := &paramtable.AutoIndexConfig{
IndexParams: paramtable.ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
assert.NotPanics(t, func() {
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, metric.COSINE, metricType)
})
t.Run("normal case, bin flat", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"index_type": "BIN_FLAT"}`)
p := &paramtable.AutoIndexConfig{
IndexParams: paramtable.ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
assert.NotPanics(t, func() {
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, metric.HAMMING, metricType)
})
t.Run("normal case, bin ivf flat", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "BIN_IVF_FLAT"}`)
p := &paramtable.AutoIndexConfig{
IndexParams: paramtable.ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
assert.NotPanics(t, func() {
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, metric.HAMMING, metricType)
})
}

View File

@ -0,0 +1,112 @@
package indexparamcheck
/*
#cgo pkg-config: milvus_core
#include <stdlib.h> // free
#include "segcore/vector_index_c.h"
*/
import "C"
import (
"fmt"
"unsafe"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexcgopb"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type vecIndexChecker struct {
baseChecker
}
// HandleCStatus deals with the error returned from CGO
func HandleCStatus(status *C.CStatus) error {
if status.error_code == 0 {
return nil
}
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return fmt.Errorf("%s", errorMsg)
}
func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
if typeutil.IsDenseFloatVectorType(dataType) {
if !CheckStrByValues(params, Metric, FloatVectorMetrics) {
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], FloatVectorMetrics)
}
} else if typeutil.IsSparseFloatVectorType(dataType) {
if !CheckStrByValues(params, Metric, SparseMetrics) {
return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics)
}
} else if typeutil.IsBinaryVectorType(dataType) {
if !CheckStrByValues(params, Metric, BinaryVectorMetrics) {
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinaryVectorMetrics)
}
}
indexType, exist := params[common.IndexTypeKey]
if !exist {
return fmt.Errorf("no indexType is specified")
}
if !vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) {
return fmt.Errorf("indexType %s is not supported", indexType)
}
protoIndexParams := &indexcgopb.IndexParams{
Params: make([]*commonpb.KeyValuePair, 0),
}
for key, value := range params {
protoIndexParams.Params = append(protoIndexParams.Params, &commonpb.KeyValuePair{Key: key, Value: value})
}
indexParamsBlob, err := proto.Marshal(protoIndexParams)
if err != nil {
return fmt.Errorf("failed to marshal index params: %s", err)
}
var status C.CStatus
cIndexType := C.CString(indexType)
cDataType := uint32(dataType)
status = C.ValidateIndexParams(cIndexType, cDataType, (*C.uint8_t)(unsafe.Pointer(&indexParamsBlob[0])), (C.uint64_t)(len(indexParamsBlob)))
C.free(unsafe.Pointer(cIndexType))
return HandleCStatus(&status)
}
func (c vecIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
if err := c.StaticCheck(dataType, params); err != nil {
return err
}
return c.baseChecker.CheckTrain(dataType, params)
}
func (c vecIndexChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
if !typeutil.IsVectorType(field.GetDataType()) {
return fmt.Errorf("index %s only supports vector data type", indexType)
}
if !vecindexmgr.GetVecIndexMgrInstance().IsDataTypeSupport(indexType, field.GetDataType()) {
return fmt.Errorf("index %s do not support data type: %s", indexType, schemapb.DataType_name[int32(field.GetDataType())])
}
return nil
}
func (c vecIndexChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
paramtable.SetDefaultMetricTypeIfNotExist(dType, params)
}
func newVecIndexChecker() IndexChecker {
return &vecIndexChecker{}
}

View File

@ -0,0 +1,132 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func TestVecIndexChecker_StaticCheck(t *testing.T) {
checker := newVecIndexChecker()
tests := []struct {
name string
dataType schemapb.DataType
params map[string]string
wantErr bool
}{
{
name: "Valid IVF_FLAT index",
dataType: schemapb.DataType_FloatVector,
params: map[string]string{
"index_type": "IVF_FLAT",
"metric_type": "L2",
"nlist": "1024",
},
wantErr: false,
},
{
name: "Invalid index type",
dataType: schemapb.DataType_FloatVector,
params: map[string]string{
"index_type": "INVALID_INDEX",
},
wantErr: true,
},
{
name: "Missing index type",
dataType: schemapb.DataType_FloatVector,
params: map[string]string{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := checker.StaticCheck(tt.dataType, tt.params)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestVecIndexChecker_CheckValidDataType(t *testing.T) {
checker := newVecIndexChecker()
tests := []struct {
name string
indexType IndexType
field *schemapb.FieldSchema
wantErr bool
}{
{
name: "Valid float vector",
indexType: "IVF_FLAT",
field: &schemapb.FieldSchema{
DataType: schemapb.DataType_FloatVector,
},
wantErr: false,
},
{
name: "Invalid data type",
indexType: "IVF_FLAT",
field: &schemapb.FieldSchema{
DataType: schemapb.DataType_Int64,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := checker.CheckValidDataType(tt.indexType, tt.field)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestVecIndexChecker_SetDefaultMetricTypeIfNotExist(t *testing.T) {
checker := newVecIndexChecker()
tests := []struct {
name string
dataType schemapb.DataType
params map[string]string
expectedType string
}{
{
name: "Float vector",
dataType: schemapb.DataType_FloatVector,
params: map[string]string{},
expectedType: FloatVectorDefaultMetricType,
},
{
name: "Binary vector",
dataType: schemapb.DataType_BinaryVector,
params: map[string]string{},
expectedType: BinaryVectorDefaultMetricType,
},
{
name: "Existing metric type",
dataType: schemapb.DataType_FloatVector,
params: map[string]string{"metric_type": "IP"},
expectedType: "IP",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
checker.SetDefaultMetricTypeIfNotExist(tt.dataType, tt.params)
assert.Equal(t, tt.expectedType, tt.params["metric_type"])
})
}
}

View File

@ -1,17 +0,0 @@
package indexparamcheck
type binFlatChecker struct {
binaryVectorBaseChecker
}
func (c binFlatChecker) CheckTrain(params map[string]string) error {
return c.binaryVectorBaseChecker.CheckTrain(params)
}
func (c binFlatChecker) StaticCheck(params map[string]string) error {
return c.staticCheck(params)
}
func newBinFlatChecker() IndexChecker {
return &binFlatChecker{}
}

View File

@ -1,33 +0,0 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_BitmapIndexChecker(t *testing.T) {
c := newBITMAPChecker()
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double, IsPrimaryKey: true}))
}

View File

@ -1,37 +0,0 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_HybridIndexChecker(t *testing.T) {
c := newHYBRIDChecker()
assert.NoError(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "100"}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
assert.Error(t, c.CheckTrain(map[string]string{}))
assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "0"}))
assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "2000"}))
}

View File

@ -1,25 +0,0 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_INVERTEDIndexChecker(t *testing.T) {
c := newINVERTEDChecker()
assert.NoError(t, c.CheckTrain(map[string]string{}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}))
}

View File

@ -1,26 +0,0 @@
package indexparamcheck
type ivfBaseChecker struct {
floatVectorBaseChecker
}
func (c ivfBaseChecker) StaticCheck(params map[string]string) error {
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
return errOutOfRange(NLIST, MinNList, MaxNList)
}
// skip check number of rows
return c.floatVectorBaseChecker.staticCheck(params)
}
func (c ivfBaseChecker) CheckTrain(params map[string]string) error {
if err := c.StaticCheck(params); err != nil {
return err
}
return c.floatVectorBaseChecker.CheckTrain(params)
}
func newIVFBaseChecker() IndexChecker {
return &ivfBaseChecker{}
}

View File

@ -1,9 +0,0 @@
package indexparamcheck
type scalarIndexChecker struct {
baseChecker
}
func (c scalarIndexChecker) CheckTrain(params map[string]string) error {
return nil
}

View File

@ -1,22 +0,0 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_STLSORTIndexChecker(t *testing.T) {
c := newSTLSORTChecker()
assert.NoError(t, c.CheckTrain(map[string]string{}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
}

View File

@ -1,23 +0,0 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_TrieIndexChecker(t *testing.T) {
c := newTRIEChecker()
assert.NoError(t, c.CheckTrain(map[string]string{}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
}

View File

@ -1,87 +0,0 @@
package indexparamcheck
import (
"strconv"
"testing"
)
func Test_CheckIntByRange(t *testing.T) {
params := map[string]string{
"1": strconv.Itoa(1),
"2": strconv.Itoa(2),
"3": strconv.Itoa(3),
"s1": "s1",
"s2": "s2",
"s3": "s3",
}
cases := []struct {
params map[string]string
key string
min int
max int
want bool
}{
{params, "1", 0, 4, true},
{params, "2", 0, 4, true},
{params, "3", 0, 4, true},
{params, "1", 4, 5, false},
{params, "2", 4, 5, false},
{params, "3", 4, 5, false},
{params, "4", 0, 4, false},
{params, "5", 0, 4, false},
{params, "6", 0, 4, false},
{params, "s1", 0, 4, false},
{params, "s2", 0, 4, false},
{params, "s3", 0, 4, false},
{params, "s4", 0, 4, false},
{params, "s5", 0, 4, false},
{params, "s6", 0, 4, false},
}
for _, test := range cases {
if got := CheckIntByRange(test.params, test.key, test.min, test.max); got != test.want {
t.Errorf("CheckIntByRange(%v, %v, %v, %v) = %v", test.params, test.key, test.min, test.max, test.want)
}
}
}
func Test_CheckStrByValues(t *testing.T) {
params := map[string]string{
"1": strconv.Itoa(1),
"2": strconv.Itoa(2),
"3": strconv.Itoa(3),
}
cases := []struct {
params map[string]string
key string
container []string
want bool
}{
{params, "1", []string{"1", "2", "3"}, true},
{params, "2", []string{"1", "2", "3"}, true},
{params, "3", []string{"1", "2", "3"}, true},
{params, "1", []string{"4", "5", "6"}, false},
{params, "2", []string{"4", "5", "6"}, false},
{params, "3", []string{"4", "5", "6"}, false},
{params, "1", []string{}, false},
{params, "2", []string{}, false},
{params, "3", []string{}, false},
{params, "4", []string{"1", "2", "3"}, false},
{params, "5", []string{"1", "2", "3"}, false},
{params, "6", []string{"1", "2", "3"}, false},
{params, "4", []string{"4", "5", "6"}, false},
{params, "5", []string{"4", "5", "6"}, false},
{params, "6", []string{"4", "5", "6"}, false},
{params, "4", []string{}, false},
{params, "5", []string{}, false},
{params, "6", []string{}, false},
}
for _, test := range cases {
if got := CheckStrByValues(test.params, test.key, test.container); got != test.want {
t.Errorf("CheckStrByValues(%v, %v, %v) = %v", test.params, test.key, test.container, test.want)
}
}
}

View File

@ -24,12 +24,13 @@ import (
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// /////////////////////////////////////////////////////////////////////////////
// --- common ---
type autoIndexConfig struct {
type AutoIndexConfig struct {
Enable ParamItem `refreshable:"true"`
EnableOptimize ParamItem `refreshable:"true"`
EnableResultLimitCheck ParamItem `refreshable:"true"`
@ -60,7 +61,7 @@ const (
DefaultBitmapCardinalityLimit = 100
)
func (p *autoIndexConfig) init(base *BaseTable) {
func (p *AutoIndexConfig) init(base *BaseTable) {
p.Enable = ParamItem{
Key: "autoIndex.enable",
Version: "2.2.0",
@ -157,7 +158,7 @@ func (p *autoIndexConfig) init(base *BaseTable) {
}
p.AutoIndexTuningConfig.Init(base.mgr)
p.panicIfNotValidAndSetDefaultMetricType(base.mgr)
p.SetDefaultMetricType(base.mgr)
p.ScalarAutoIndexEnable = ParamItem{
Key: "scalarAutoIndex.enable",
@ -244,37 +245,47 @@ func (p *autoIndexConfig) init(base *BaseTable) {
p.ScalarBoolIndexType.Init(base.mgr)
}
func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricType(mgr *config.Manager) {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
// SetDefaultMetricType The config check logic has been moved to internal package; only set defulat metric here
func (p *AutoIndexConfig) SetDefaultMetricType(mgr *config.Manager) {
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
p.SetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
p.SetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
}
func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricTypeHelper(key string, m map[string]string, dtype schemapb.DataType, mgr *config.Manager) {
func setDefaultIfNotExist(params map[string]string, key string, defaultValue string) {
_, exist := params[key]
if !exist {
params[key] = defaultValue
}
}
const (
FloatVectorDefaultMetricType = metric.COSINE
SparseFloatVectorDefaultMetricType = metric.IP
BinaryVectorDefaultMetricType = metric.HAMMING
)
func SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
if typeutil.IsDenseFloatVectorType(dType) {
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
} else if typeutil.IsSparseFloatVectorType(dType) {
setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType)
} else if typeutil.IsBinaryVectorType(dType) {
setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType)
}
}
func (p *AutoIndexConfig) SetDefaultMetricTypeHelper(key string, m map[string]string, dtype schemapb.DataType, mgr *config.Manager) {
if m == nil {
panic(fmt.Sprintf("%s invalid, should be json format", key))
}
indexType, ok := m[common.IndexTypeKey]
if !ok {
panic(fmt.Sprintf("%s invalid, index type not found", key))
}
checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType)
if err != nil {
panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType))
}
checker.SetDefaultMetricTypeIfNotExist(m, dtype)
if err := checker.StaticCheck(m); err != nil {
panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error()))
}
SetDefaultMetricTypeIfNotExist(dtype, m)
p.reset(key, m, mgr)
}
func (p *autoIndexConfig) reset(key string, m map[string]string, mgr *config.Manager) {
func (p *AutoIndexConfig) reset(key string, m map[string]string, mgr *config.Manager) {
ret, err := funcutil.MapToJSON(m)
if err != nil {
panic(fmt.Sprintf("%s: convert to json failed, parameters invalid, error: %s", key, err.Error()))

View File

@ -26,7 +26,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
)
const (
@ -134,180 +133,16 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) {
t.Run("not in json format", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", "not in json format")
p := &autoIndexConfig{
p := &AutoIndexConfig{
IndexParams: ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.Panics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
})
})
t.Run("index type not found", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"M": 30}`)
p := &autoIndexConfig{
IndexParams: ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.Panics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
})
})
t.Run("unsupported index type", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"index_type": "not supported"}`)
p := &autoIndexConfig{
IndexParams: ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.Panics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
})
})
t.Run("normal case, hnsw", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"M": 30,"efConstruction": 360,"index_type": "HNSW"}`)
p := &autoIndexConfig{
IndexParams: ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.NotPanics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
})
t.Run("normal case, binary vector", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.binary.build", `{"nlist": 1024, "index_type": "BIN_IVF_FLAT"}`)
p := &autoIndexConfig{
BinaryIndexParams: ParamItem{
Key: "autoIndex.params.binary.build",
},
}
p.BinaryIndexParams.Init(mgr)
assert.NotPanics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
})
metricType, exist := p.BinaryIndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType)
})
t.Run("normal case, sparse vector", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.sparse.build", `{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}`)
p := &autoIndexConfig{
SparseIndexParams: ParamItem{
Key: "autoIndex.params.sparse.build",
},
}
p.SparseIndexParams.Init(mgr)
assert.NotPanics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
})
metricType, exist := p.SparseIndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, indexparamcheck.SparseFloatVectorDefaultMetricType, metricType)
})
t.Run("normal case, ivf flat", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
p := &autoIndexConfig{
IndexParams: ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.NotPanics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
})
t.Run("normal case, ivf flat", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
p := &autoIndexConfig{
IndexParams: ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.NotPanics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
})
t.Run("normal case, diskann", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"index_type": "DISKANN"}`)
p := &autoIndexConfig{
IndexParams: ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.NotPanics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
})
t.Run("normal case, bin flat", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"index_type": "BIN_FLAT"}`)
p := &autoIndexConfig{
IndexParams: ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.NotPanics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType)
})
t.Run("normal case, bin ivf flat", func(t *testing.T) {
mgr := config.NewManager()
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "BIN_IVF_FLAT"}`)
p := &autoIndexConfig{
IndexParams: ParamItem{
Key: "autoIndex.params.build",
},
}
p.IndexParams.Init(mgr)
assert.NotPanics(t, func() {
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
})
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
assert.True(t, exist)
assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType)
})
}
func TestScalarAutoIndexParams_build(t *testing.T) {

View File

@ -65,7 +65,7 @@ type ComponentParam struct {
CommonCfg commonConfig
QuotaConfig quotaConfig
AutoIndexConfig autoIndexConfig
AutoIndexConfig AutoIndexConfig
GpuConfig gpuConfig
TraceCfg traceConfig

View File

@ -607,7 +607,7 @@ func TestCreateIndexJsonField(t *testing.T) {
// create vector index on json field
idx := index.NewSCANNIndex(entity.L2, 8, false)
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultJSONFieldName, idx).WithIndexName("json_index"))
common.CheckErr(t, err, false, "data type should be FloatVector, Float16Vector or BFloat16Vector")
common.CheckErr(t, err, false, "index SCANN only supports vector data type")
// create scalar index on json field
type scalarIndexError struct {
@ -653,7 +653,7 @@ func TestCreateUnsupportedIndexArrayField(t *testing.T) {
if field.DataType == entity.FieldTypeArray {
// create vector index
_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, vectorIdx).WithIndexName("vector_index"))
common.CheckErr(t, err1, false, "data type should be FloatVector, Float16Vector or BFloat16Vector")
common.CheckErr(t, err1, false, "index SCANN only supports vector data type")
// create scalar index
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxErr.idx))
@ -840,11 +840,11 @@ func TestCreateSparseIndexInvalidParams(t *testing.T) {
for _, drb := range []float64{-0.3, 1.3} {
idxInverted := index.NewSparseInvertedIndex(entity.IP, drb)
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted))
common.CheckErr(t, err, false, "must be in range [0, 1)")
common.CheckErr(t, err, false, "Out of range in json: param 'drop_ratio_build'")
idxWand := index.NewSparseWANDIndex(entity.IP, drb)
_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand))
common.CheckErr(t, err1, false, "must be in range [0, 1)")
common.CheckErr(t, err1, false, "Out of range in json: param 'drop_ratio_build'")
}
}
@ -944,20 +944,22 @@ func TestCreateVectorIndexScalarField(t *testing.T) {
// create float vector index on scalar field
for _, idx := range hp.GenAllFloatIndex(entity.COSINE) {
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idx))
common.CheckErr(t, err, false, "can't build hnsw in not vector type",
"data type should be FloatVector, Float16Vector or BFloat16Vector")
expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idx.IndexType())
common.CheckErr(t, err, false, expErrorMsg)
}
// create binary vector index on scalar field
for _, idxBinary := range []index.Index{index.NewBinFlatIndex(entity.IP), index.NewBinIvfFlatIndex(entity.COSINE, 64)} {
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxBinary))
common.CheckErr(t, err, false, "binary vector is only supported")
expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idxBinary.IndexType())
common.CheckErr(t, err, false, expErrorMsg)
}
// create sparse vector index on scalar field
for _, idxSparse := range []index.Index{index.NewSparseInvertedIndex(entity.IP, 0.2), index.NewSparseWANDIndex(entity.IP, 0.3)} {
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxSparse))
common.CheckErr(t, err, false, "only sparse float vector is supported for the specified index")
expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idxSparse.IndexType())
common.CheckErr(t, err, false, expErrorMsg)
}
}
}
@ -972,7 +974,7 @@ func TestCreateIndexInvalidParams(t *testing.T) {
_, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
// invalid IvfFlat nlist [1, 65536]
errMsg := "nlist out of range: [1, 65536]"
errMsg := "Out of range in json: param 'nlist'"
for _, invalidNlist := range []int{0, -1, 65536 + 1} {
// IvfFlat
idxIvfFlat := index.NewIvfFlatIndex(entity.L2, invalidNlist)
@ -997,7 +999,7 @@ func TestCreateIndexInvalidParams(t *testing.T) {
// IvfFlat
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 8, invalidNBits)
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxIvfPq))
common.CheckErr(t, err, false, "parameter `nbits` out of range, expect range [1,64]")
common.CheckErr(t, err, false, "Out of range in json: param 'nbits'")
}
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 7, 8)
@ -1009,13 +1011,13 @@ func TestCreateIndexInvalidParams(t *testing.T) {
// IvfFlat
idxHnsw := index.NewHNSWIndex(entity.L2, invalidM, 96)
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxHnsw))
common.CheckErr(t, err, false, "M out of range: [1, 2048]")
common.CheckErr(t, err, false, "Out of range in json: param 'M'")
}
for _, invalidEfConstruction := range []int{0, 2147483647 + 1} {
// IvfFlat
idxHnsw := index.NewHNSWIndex(entity.L2, 8, invalidEfConstruction)
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxHnsw))
common.CheckErr(t, err, false, "efConstruction out of range: [1, 2147483647]")
common.CheckErr(t, err, false, "Out of range in json: param 'efConstruction'", "integer value out of range, key: 'efConstruction'")
}
}

View File

@ -33,10 +33,10 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/importutilv2"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/tests/integration"
@ -66,7 +66,7 @@ func (s *BulkInsertSuite) SetupTest() {
s.autoID = false
s.vecType = schemapb.DataType_FloatVector
s.indexType = indexparamcheck.IndexHNSW
s.indexType = "HNSW"
s.metricType = metric.L2
}
@ -225,29 +225,29 @@ func (s *BulkInsertSuite) TestMultiFileTypes() {
s.fileType = fileType
s.vecType = schemapb.DataType_BinaryVector
s.indexType = indexparamcheck.IndexFaissBinIvfFlat
s.indexType = "BIN_IVF_FLAT"
s.metricType = metric.HAMMING
s.run()
s.vecType = schemapb.DataType_FloatVector
s.indexType = indexparamcheck.IndexHNSW
s.indexType = "HNSW"
s.metricType = metric.L2
s.run()
s.vecType = schemapb.DataType_Float16Vector
s.indexType = indexparamcheck.IndexHNSW
s.indexType = "HNSW"
s.metricType = metric.L2
s.run()
s.vecType = schemapb.DataType_BFloat16Vector
s.indexType = indexparamcheck.IndexHNSW
s.indexType = "HNSW"
s.metricType = metric.L2
s.run()
// TODO: not support numpy for SparseFloatVector by now
if fileType != importutilv2.Numpy {
s.vecType = schemapb.DataType_SparseFloatVector
s.indexType = indexparamcheck.IndexSparseWand
s.indexType = "SPARSE_WAND"
s.metricType = metric.IP
s.run()
}

View File

@ -26,23 +26,22 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
)
const (
IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat
IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ
IndexFaissIDMap = indexparamcheck.IndexFaissIDMap
IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat
IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ
IndexScaNN = indexparamcheck.IndexScaNN
IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8
IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap
IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat
IndexHNSW = indexparamcheck.IndexHNSW
IndexDISKANN = indexparamcheck.IndexDISKANN
IndexSparseInvertedIndex = indexparamcheck.IndexSparseInverted
IndexSparseWand = indexparamcheck.IndexSparseWand
IndexRaftIvfFlat = "GPU_IVF_FLAT"
IndexRaftIvfPQ = "GPU_IVF_PQ"
IndexFaissIDMap = "FLAT"
IndexFaissIvfFlat = "IVF_FLAT"
IndexFaissIvfPQ = "IVF_PQ"
IndexScaNN = "SCANN"
IndexFaissIvfSQ8 = "IVF_SQ8"
IndexFaissBinIDMap = "BIN_FLAT"
IndexFaissBinIvfFlat = "BIN_IVF_FLAT"
IndexHNSW = "HNSW"
IndexDISKANN = "DISKANN"
IndexSparseInvertedIndex = "SPARSE_INVERTED_INDEX"
IndexSparseWand = "SPARSE_WAND"
)
func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {

View File

@ -2103,6 +2103,8 @@ def gen_simple_index():
continue
elif ct.all_index_types[i] in ct.sparse_support:
continue
elif ct.all_index_types[i] in ct.gpu_support:
continue
dic = {"index_type": ct.all_index_types[i], "metric_type": "L2"}
dic.update({"params": ct.default_all_indexes_params[i]})
index_params.append(dic)

View File

@ -244,6 +244,7 @@ default_all_search_params_params = [{}, {"nprobe": 32}, {"nprobe": 32}, {"nprobe
Handler_type = ["GRPC", "HTTP"]
binary_support = ["BIN_FLAT", "BIN_IVF_FLAT"]
sparse_support = ["SPARSE_INVERTED_INDEX", "SPARSE_WAND"]
gpu_support = ["GPU_IVF_FLAT", "GPU_IVF_PQ"]
default_L0_metric = "COSINE"
float_metrics = ["L2", "IP", "COSINE"]
binary_metrics = ["JACCARD", "HAMMING", "SUBSTRUCTURE", "SUPERSTRUCTURE"]

View File

@ -57,6 +57,8 @@ default_index_params = [
def create_target_index(index, field_name):
index["field_name"] = field_name
def gpu_support():
return ["GPU_IVF_FLAT", "GPU_IVF_PQ"]
def binary_support():
return ["BIN_FLAT", "BIN_IVF_FLAT"]
@ -764,6 +766,8 @@ def gen_simple_index():
for i in range(len(all_index_types)):
if all_index_types[i] in binary_support():
continue
if all_index_types[i] in gpu_support():
continue
dic = {"index_type": all_index_types[i], "metric_type": "L2"}
dic.update({"params": default_index_params[i]})
index_params.append(dic)