mirror of https://github.com/milvus-io/milvus.git
enhance: add an unify vector index config checker (#36844)
issue: #34298 Signed-off-by: xianliang.li <xianliang.li@zilliz.com>pull/37180/head
parent
eeb67a3845
commit
d7b2ffe5aa
|
@ -33,6 +33,7 @@ func NewDiskANNIndex(metricType MetricType) Index {
|
|||
return &diskANNIndex{
|
||||
baseIndex: baseIndex{
|
||||
metricType: metricType,
|
||||
indexType: DISKANN,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ func NewFlatIndex(metricType MetricType) Index {
|
|||
return flatIndex{
|
||||
baseIndex: baseIndex{
|
||||
metricType: metricType,
|
||||
indexType: Flat,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -54,6 +55,7 @@ func NewBinFlatIndex(metricType MetricType) Index {
|
|||
return binFlatIndex{
|
||||
baseIndex: baseIndex{
|
||||
metricType: metricType,
|
||||
indexType: BinFlat,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
grpcproxy "github.com/milvus-io/milvus/internal/distributed/proxy"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
|
@ -50,6 +51,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
|
|||
}
|
||||
|
||||
func (n *Proxy) Prepare() error {
|
||||
indexparamcheck.ValidateParamTable()
|
||||
return n.svr.Prepare()
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,79 @@
|
|||
#include "index/IndexFactory.h"
|
||||
#include "pb/index_cgo_msg.pb.h"
|
||||
|
||||
CStatus
|
||||
ValidateIndexParams(const char* index_type,
|
||||
enum CDataType data_type,
|
||||
const uint8_t* serialized_index_params,
|
||||
const uint64_t length) {
|
||||
try {
|
||||
auto index_params =
|
||||
std::make_unique<milvus::proto::indexcgo::IndexParams>();
|
||||
auto res =
|
||||
index_params->ParseFromArray(serialized_index_params, length);
|
||||
AssertInfo(res, "Unmarshall index params failed");
|
||||
|
||||
knowhere::Json json;
|
||||
|
||||
for (size_t i = 0; i < index_params->params_size(); i++) {
|
||||
auto& param = index_params->params(i);
|
||||
json[param.key()] = param.value();
|
||||
}
|
||||
|
||||
milvus::DataType dataType(static_cast<milvus::DataType>(data_type));
|
||||
|
||||
knowhere::Status status;
|
||||
std::string error_msg;
|
||||
if (dataType == milvus::DataType::VECTOR_BINARY) {
|
||||
status = knowhere::IndexStaticFaced<knowhere::bin1>::ConfigCheck(
|
||||
index_type,
|
||||
knowhere::Version::GetCurrentVersion().VersionNumber(),
|
||||
json,
|
||||
error_msg);
|
||||
} else if (dataType == milvus::DataType::VECTOR_FLOAT) {
|
||||
status = knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
|
||||
index_type,
|
||||
knowhere::Version::GetCurrentVersion().VersionNumber(),
|
||||
json,
|
||||
error_msg);
|
||||
} else if (dataType == milvus::DataType::VECTOR_BFLOAT16) {
|
||||
status = knowhere::IndexStaticFaced<knowhere::bf16>::ConfigCheck(
|
||||
index_type,
|
||||
knowhere::Version::GetCurrentVersion().VersionNumber(),
|
||||
json,
|
||||
error_msg);
|
||||
} else if (dataType == milvus::DataType::VECTOR_FLOAT16) {
|
||||
status = knowhere::IndexStaticFaced<knowhere::fp16>::ConfigCheck(
|
||||
index_type,
|
||||
knowhere::Version::GetCurrentVersion().VersionNumber(),
|
||||
json,
|
||||
error_msg);
|
||||
} else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) {
|
||||
status = knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
|
||||
index_type,
|
||||
knowhere::Version::GetCurrentVersion().VersionNumber(),
|
||||
json,
|
||||
error_msg);
|
||||
} else {
|
||||
status = knowhere::Status::invalid_args;
|
||||
}
|
||||
CStatus cStatus;
|
||||
if (status == knowhere::Status::success) {
|
||||
cStatus.error_code = milvus::Success;
|
||||
cStatus.error_msg = "";
|
||||
} else {
|
||||
cStatus.error_code = milvus::ConfigInvalid;
|
||||
cStatus.error_msg = strdup(error_msg.c_str());
|
||||
}
|
||||
return cStatus;
|
||||
} catch (std::exception& e) {
|
||||
auto cStatus = CStatus();
|
||||
cStatus.error_code = milvus::UnexpectedError;
|
||||
cStatus.error_msg = strdup(e.what());
|
||||
return cStatus;
|
||||
}
|
||||
}
|
||||
|
||||
int
|
||||
GetIndexListSize() {
|
||||
return knowhere::IndexFactory::Instance().GetIndexFeatures().size();
|
||||
|
|
|
@ -17,6 +17,12 @@ extern "C" {
|
|||
#include <stdbool.h>
|
||||
#include "common/type_c.h"
|
||||
|
||||
CStatus
|
||||
ValidateIndexParams(const char* index_type,
|
||||
enum CDataType data_type,
|
||||
const uint8_t* index_params,
|
||||
const uint64_t length);
|
||||
|
||||
int
|
||||
GetIndexListSize();
|
||||
|
||||
|
|
|
@ -34,11 +34,11 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/workerpb"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
|
|
|
@ -28,11 +28,11 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metautil"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
|
|
|
@ -42,9 +42,9 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/workerpb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
|
@ -620,13 +620,13 @@ func TestServer_AlterIndex(t *testing.T) {
|
|||
s.stateCode.Store(commonpb.StateCode_Healthy)
|
||||
|
||||
t.Run("mmap_unsupported", func(t *testing.T) {
|
||||
indexParams[0].Value = indexparamcheck.IndexRaftCagra
|
||||
indexParams[0].Value = "GPU_CAGRA"
|
||||
|
||||
resp, err := s.AlterIndex(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
|
||||
|
||||
indexParams[0].Value = indexparamcheck.IndexFaissIvfFlat
|
||||
indexParams[0].Value = "IVF_FLAT"
|
||||
})
|
||||
|
||||
t.Run("param_value_invalied", func(t *testing.T) {
|
||||
|
|
|
@ -28,13 +28,13 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
@ -475,25 +475,18 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
|
|||
if err := fillDimension(field, indexParams); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// used only for checker, should be deleted after checking
|
||||
indexParams[IsSparseKey] = "true"
|
||||
}
|
||||
|
||||
if err := checker.CheckValidDataType(field); err != nil {
|
||||
if err := checker.CheckValidDataType(indexType, field); err != nil {
|
||||
log.Info("create index with invalid data type", zap.Error(err), zap.String("data_type", field.GetDataType().String()))
|
||||
return err
|
||||
}
|
||||
|
||||
if err := checker.CheckTrain(indexParams); err != nil {
|
||||
if err := checker.CheckTrain(field.DataType, indexParams); err != nil {
|
||||
log.Info("create index with invalid parameters", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if isSparse {
|
||||
delete(indexParams, IsSparseKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -35,9 +35,9 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
|
|
|
@ -40,6 +40,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
|
@ -48,7 +49,6 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
|
|
|
@ -29,11 +29,11 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/conc"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
|
|
|
@ -24,8 +24,8 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
|
|
@ -52,12 +52,12 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/querynodev2/segments/state"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/cgo"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metautil"
|
||||
|
|
|
@ -33,11 +33,11 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/internal/util/initcore"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
|
|
|
@ -29,12 +29,12 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/querycoordv2/params"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments/metricsutil"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
|
|
|
@ -9,11 +9,11 @@ type AUTOINDEXChecker struct {
|
|||
baseChecker
|
||||
}
|
||||
|
||||
func (c *AUTOINDEXChecker) CheckTrain(params map[string]string) error {
|
||||
func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AUTOINDEXChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c *AUTOINDEXChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -19,29 +19,17 @@ package indexparamcheck
|
|||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type baseChecker struct{}
|
||||
|
||||
func (c baseChecker) CheckTrain(params map[string]string) error {
|
||||
// vector dimension should be checked on collection creation. this is just some basic check
|
||||
isSparse := false
|
||||
if val, exist := params[common.IsSparseKey]; exist {
|
||||
val = strings.ToLower(val)
|
||||
if val != "true" && val != "false" {
|
||||
return fmt.Errorf("invalid is_sparse value: %s, must be true or false", val)
|
||||
}
|
||||
if val == "true" {
|
||||
isSparse = true
|
||||
}
|
||||
}
|
||||
if isSparse {
|
||||
func (c baseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if typeutil.IsSparseFloatVectorType(dataType) {
|
||||
if !CheckStrByValues(params, Metric, SparseMetrics) {
|
||||
return fmt.Errorf("metric type not found or not supported for sparse float vectors, supported: %v", SparseMetrics)
|
||||
}
|
||||
|
@ -55,13 +43,13 @@ func (c baseChecker) CheckTrain(params map[string]string) error {
|
|||
}
|
||||
|
||||
// CheckValidDataType check whether the field data type is supported for the index type
|
||||
func (c baseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c baseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c baseChecker) SetDefaultMetricTypeIfNotExist(m map[string]string, dType schemapb.DataType) {}
|
||||
func (c baseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, m map[string]string) {}
|
||||
|
||||
func (c baseChecker) StaticCheck(params map[string]string) error {
|
||||
func (c baseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||
return errors.New("unsupported index type")
|
||||
}
|
||||
|
|
@ -21,7 +21,7 @@ func Test_baseChecker_CheckTrain(t *testing.T) {
|
|||
}
|
||||
sparseParamsWithoutDim := map[string]string{
|
||||
Metric: metric.IP,
|
||||
common.IsSparseKey: "tRue",
|
||||
common.IsSparseKey: "True",
|
||||
}
|
||||
sparseParamsWrongMetric := map[string]string{
|
||||
Metric: metric.L2,
|
||||
|
@ -42,9 +42,15 @@ func Test_baseChecker_CheckTrain(t *testing.T) {
|
|||
{badSparseParams, false},
|
||||
}
|
||||
|
||||
c := newBaseChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "HNSW"
|
||||
var err error
|
||||
if test.params[common.IsSparseKey] == "True" {
|
||||
err = c.CheckTrain(schemapb.DataType_SparseFloatVector, test.params)
|
||||
} else {
|
||||
err = c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
}
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -115,7 +121,7 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) {
|
|||
c := newBaseChecker()
|
||||
for _, test := range cases {
|
||||
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
||||
err := c.CheckValidDataType(fieldSchema)
|
||||
err := c.CheckValidDataType("FLAT", fieldSchema)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -126,5 +132,5 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) {
|
|||
|
||||
func Test_baseChecker_StaticCheck(t *testing.T) {
|
||||
// TODO
|
||||
assert.Error(t, newBaseChecker().StaticCheck(nil))
|
||||
assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, nil))
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package indexparamcheck
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
|
||||
type binFlatChecker struct {
|
||||
binaryVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c binFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
return c.binaryVectorBaseChecker.CheckTrain(dataType, params)
|
||||
}
|
||||
|
||||
func (c binFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||
return c.staticCheck(params)
|
||||
}
|
||||
|
||||
func newBinFlatChecker() IndexChecker {
|
||||
return &binFlatChecker{}
|
||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -64,9 +65,10 @@ func Test_binFlatChecker_CheckTrain(t *testing.T) {
|
|||
{p7, true},
|
||||
}
|
||||
|
||||
c := newBinFlatChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "BINFLAT"
|
||||
err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -134,10 +136,10 @@ func Test_binFlatChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newBinFlatChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
|
||||
for _, test := range cases {
|
||||
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
||||
err := c.CheckValidDataType(fieldSchema)
|
||||
err := c.CheckValidDataType("BINFLAT", fieldSchema)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -2,13 +2,15 @@ package indexparamcheck
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
type binIVFFlatChecker struct {
|
||||
binaryVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c binIVFFlatChecker) StaticCheck(params map[string]string) error {
|
||||
func (c binIVFFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||
if !CheckStrByValues(params, Metric, BinIvfMetrics) {
|
||||
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinIvfMetrics)
|
||||
}
|
||||
|
@ -20,12 +22,12 @@ func (c binIVFFlatChecker) StaticCheck(params map[string]string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c binIVFFlatChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.binaryVectorBaseChecker.CheckTrain(params); err != nil {
|
||||
func (c binIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.binaryVectorBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.StaticCheck(params)
|
||||
return c.StaticCheck(schemapb.DataType_BinaryVector, params)
|
||||
}
|
||||
|
||||
func newBinIVFFlatChecker() IndexChecker {
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -115,9 +116,10 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
|
|||
{p7, false},
|
||||
}
|
||||
|
||||
c := newBinIVFFlatChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT")
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "BIN_IVF_FLAT"
|
||||
err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -185,10 +187,10 @@ func Test_binIVFFlatChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newBinIVFFlatChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT")
|
||||
for _, test := range cases {
|
||||
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
||||
err := c.CheckValidDataType(fieldSchema)
|
||||
err := c.CheckValidDataType("BIN_IVF_FLAT", fieldSchema)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -19,22 +19,22 @@ func (c binaryVectorBaseChecker) staticCheck(params map[string]string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c binaryVectorBaseChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.baseChecker.CheckTrain(params); err != nil {
|
||||
func (c binaryVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.baseChecker.CheckTrain(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.staticCheck(params)
|
||||
}
|
||||
|
||||
func (c binaryVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c binaryVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
if field.GetDataType() != schemapb.DataType_BinaryVector {
|
||||
return fmt.Errorf("binary vector is only supported")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
|
||||
func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||
setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType)
|
||||
}
|
||||
|
|
@ -67,10 +67,10 @@ func Test_binaryVectorBaseChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newBinaryVectorBaseChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
|
||||
for _, test := range cases {
|
||||
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
||||
err := c.CheckValidDataType(fieldSchema)
|
||||
err := c.CheckValidDataType("BINFLAT", fieldSchema)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -0,0 +1,33 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_BitmapIndexChecker(t *testing.T) {
|
||||
c := newBITMAPChecker()
|
||||
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
|
||||
|
||||
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double, IsPrimaryKey: true}))
|
||||
}
|
|
@ -11,11 +11,11 @@ type BITMAPChecker struct {
|
|||
scalarIndexChecker
|
||||
}
|
||||
|
||||
func (c *BITMAPChecker) CheckTrain(params map[string]string) error {
|
||||
return c.scalarIndexChecker.CheckTrain(params)
|
||||
func (c *BITMAPChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
return c.scalarIndexChecker.CheckTrain(dataType, params)
|
||||
}
|
||||
|
||||
func (c *BITMAPChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c *BITMAPChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
if field.IsPrimaryKey {
|
||||
return fmt.Errorf("create bitmap index on primary key not supported")
|
||||
}
|
|
@ -3,6 +3,8 @@ package indexparamcheck
|
|||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
// diskannChecker checks if an diskann index can be built.
|
||||
|
@ -10,8 +12,8 @@ type cagraChecker struct {
|
|||
floatVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c *cagraChecker) CheckTrain(params map[string]string) error {
|
||||
err := c.baseChecker.CheckTrain(params)
|
||||
func (c *cagraChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
err := c.baseChecker.CheckTrain(dataType, params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -54,7 +56,7 @@ func (c *cagraChecker) CheckTrain(params map[string]string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c cagraChecker) StaticCheck(params map[string]string) error {
|
||||
func (c cagraChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||
return c.staticCheck(params)
|
||||
}
|
||||
|
|
@ -6,6 +6,8 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -101,9 +103,13 @@ func Test_cagraChecker_CheckTrain(t *testing.T) {
|
|||
{p14, false},
|
||||
}
|
||||
|
||||
c := newCagraChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_CAGRA")
|
||||
if c == nil {
|
||||
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||
return
|
||||
}
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -20,6 +20,8 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
)
|
||||
|
||||
type IndexCheckerMgr interface {
|
||||
|
@ -34,36 +36,19 @@ type indexCheckerMgrImpl struct {
|
|||
|
||||
func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, error) {
|
||||
mgr.once.Do(mgr.registerIndexChecker)
|
||||
|
||||
// Unify the vector index checker
|
||||
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) {
|
||||
return mgr.checkers[IndexVector], nil
|
||||
}
|
||||
adapter, ok := mgr.checkers[indexType]
|
||||
if ok {
|
||||
return adapter, nil
|
||||
}
|
||||
return nil, errors.New("Can not find conf adapter: " + indexType)
|
||||
return nil, errors.New("Can not find index: " + indexType + " , please check")
|
||||
}
|
||||
|
||||
func (mgr *indexCheckerMgrImpl) registerIndexChecker() {
|
||||
mgr.checkers[IndexRaftIvfFlat] = newRaftIVFFlatChecker()
|
||||
mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker()
|
||||
mgr.checkers[IndexRaftCagra] = newCagraChecker()
|
||||
mgr.checkers[IndexRaftBruteForce] = newRaftBruteForceChecker()
|
||||
mgr.checkers[IndexFaissIDMap] = newFlatChecker()
|
||||
mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker()
|
||||
mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker()
|
||||
mgr.checkers[IndexScaNN] = newScaNNChecker()
|
||||
mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker()
|
||||
mgr.checkers[IndexFaissBinIDMap] = newBinFlatChecker()
|
||||
mgr.checkers[IndexFaissBinIvfFlat] = newBinIVFFlatChecker()
|
||||
mgr.checkers[IndexHNSW] = newHnswChecker()
|
||||
mgr.checkers[IndexDISKANN] = newDiskannChecker()
|
||||
mgr.checkers[IndexSparseInverted] = newSparseInvertedIndexChecker()
|
||||
mgr.checkers[IndexFaissHNSW] = newFloatVectorBaseChecker()
|
||||
mgr.checkers[IndexFaissHNSWPQ] = newFloatVectorBaseChecker()
|
||||
mgr.checkers[IndexFaissHNSWSQ] = newFloatVectorBaseChecker()
|
||||
mgr.checkers[IndexFaissHNSWPRQ] = newFloatVectorBaseChecker()
|
||||
// WAND doesn't have more index params than sparse inverted index, thus
|
||||
// using the same checker.
|
||||
mgr.checkers[IndexSparseWand] = newSparseInvertedIndexChecker()
|
||||
mgr.checkers[IndexVector] = newVecIndexChecker()
|
||||
mgr.checkers[IndexINVERTED] = newINVERTEDChecker()
|
||||
mgr.checkers[IndexSTLSORT] = newSTLSORTChecker()
|
||||
mgr.checkers["Asceneding"] = newSTLSORTChecker()
|
|
@ -29,52 +29,52 @@ func Test_GetConfAdapterMgrInstance(t *testing.T) {
|
|||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, adapter)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
|
||||
adapter, err = adapterMgr.GetChecker("FLAT")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*flatChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
|
||||
adapter, err = adapterMgr.GetChecker("IVF_FLAT")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*ivfBaseChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexScaNN)
|
||||
adapter, err = adapterMgr.GetChecker("SCANN")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*scaNNChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ)
|
||||
adapter, err = adapterMgr.GetChecker("IVF_PQ")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*ivfPQChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8)
|
||||
adapter, err = adapterMgr.GetChecker("IVF_SQ8")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*ivfSQChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap)
|
||||
adapter, err = adapterMgr.GetChecker("BIN_FLAT")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*binFlatChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat)
|
||||
adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*binIVFFlatChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexHNSW)
|
||||
adapter, err = adapterMgr.GetChecker("HNSW")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*hnswChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
|
@ -89,52 +89,52 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) {
|
|||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, adapter)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
|
||||
adapter, err = adapterMgr.GetChecker("FLAT")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*flatChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
|
||||
adapter, err = adapterMgr.GetChecker("IVF_FLAT")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*ivfBaseChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexScaNN)
|
||||
adapter, err = adapterMgr.GetChecker("SCANN")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*scaNNChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ)
|
||||
adapter, err = adapterMgr.GetChecker("IVF_PQ")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*ivfPQChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8)
|
||||
adapter, err = adapterMgr.GetChecker("IVF_SQ8")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*ivfSQChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap)
|
||||
adapter, err = adapterMgr.GetChecker("BIN_FLAT")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*binFlatChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat)
|
||||
adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*binIVFFlatChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexHNSW)
|
||||
adapter, err = adapterMgr.GetChecker("HNSW")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*hnswChecker)
|
||||
_, ok = adapter.(*vecIndexChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
|
@ -146,7 +146,7 @@ func TestConfAdapterMgrImpl_GetAdapter_multiple_threads(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
adapter, err := mgr.GetChecker(IndexHNSW)
|
||||
adapter, err := mgr.GetChecker("HNSW")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, adapter)
|
||||
}()
|
|
@ -65,7 +65,7 @@ var (
|
|||
CagraBuildAlgoTypes = []string{CargaBuildAlgoIVFPQ, CargaBuildAlgoNNDESCENT}
|
||||
supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
|
||||
supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
|
||||
SparseMetrics = []string{metric.IP} // const
|
||||
SparseMetrics = []string{metric.IP, metric.BM25} // const
|
||||
)
|
||||
|
||||
const (
|
|
@ -1,11 +1,13 @@
|
|||
package indexparamcheck
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
|
||||
// diskannChecker checks if an diskann index can be built.
|
||||
type diskannChecker struct {
|
||||
floatVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c diskannChecker) StaticCheck(params map[string]string) error {
|
||||
func (c diskannChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||
return c.staticCheck(params)
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -72,9 +73,10 @@ func Test_diskannChecker_CheckTrain(t *testing.T) {
|
|||
{p7, false},
|
||||
}
|
||||
|
||||
c := newDiskannChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN")
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "DISKANN"
|
||||
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -142,9 +144,9 @@ func Test_diskannChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newDiskannChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN")
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
||||
err := c.CheckValidDataType("DISKANN", &schemapb.FieldSchema{DataType: test.dType})
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -1,10 +1,12 @@
|
|||
package indexparamcheck
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
|
||||
type flatChecker struct {
|
||||
floatVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c flatChecker) StaticCheck(m map[string]string) error {
|
||||
func (c flatChecker) StaticCheck(dataType schemapb.DataType, m map[string]string) error {
|
||||
return c.staticCheck(m)
|
||||
}
|
||||
|
|
@ -6,6 +6,8 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -52,9 +54,10 @@ func Test_flatChecker_CheckTrain(t *testing.T) {
|
|||
{p7, false},
|
||||
}
|
||||
|
||||
c := newFlatChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT")
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "FLAT"
|
||||
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -89,9 +92,10 @@ func Test_flatChecker_StaticCheck(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newFlatChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT")
|
||||
for _, test := range cases {
|
||||
err := c.StaticCheck(test.params)
|
||||
test.params[common.IndexTypeKey] = "FLAT"
|
||||
err := c.StaticCheck(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -20,22 +20,22 @@ func (c floatVectorBaseChecker) staticCheck(params map[string]string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.baseChecker.CheckTrain(params); err != nil {
|
||||
func (c floatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.baseChecker.CheckTrain(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.staticCheck(params)
|
||||
}
|
||||
|
||||
func (c floatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c floatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
if !typeutil.IsDenseFloatVectorType(field.GetDataType()) {
|
||||
return fmt.Errorf("data type should be FloatVector, Float16Vector or BFloat16Vector")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
|
||||
func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
|
||||
}
|
||||
|
|
@ -63,13 +63,13 @@ func Test_floatVectorBaseChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: false,
|
||||
errIsNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
c := newFloatVectorBaseChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
||||
err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType})
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -12,7 +12,7 @@ type hnswChecker struct {
|
|||
baseChecker
|
||||
}
|
||||
|
||||
func (c hnswChecker) StaticCheck(params map[string]string) error {
|
||||
func (c hnswChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
|
||||
return errOutOfRange(EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction)
|
||||
}
|
||||
|
@ -25,21 +25,21 @@ func (c hnswChecker) StaticCheck(params map[string]string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c hnswChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.StaticCheck(params); err != nil {
|
||||
func (c hnswChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.StaticCheck(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.baseChecker.CheckTrain(params)
|
||||
return c.baseChecker.CheckTrain(dataType, params)
|
||||
}
|
||||
|
||||
func (c hnswChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c hnswChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
if !typeutil.IsVectorType(field.GetDataType()) {
|
||||
return fmt.Errorf("can't build hnsw in not vector type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c hnswChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
|
||||
func (c hnswChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||
if typeutil.IsDenseFloatVectorType(dType) {
|
||||
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
|
||||
} else if typeutil.IsSparseFloatVectorType(dType) {
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -88,13 +89,19 @@ func Test_hnswChecker_CheckTrain(t *testing.T) {
|
|||
{p3, true},
|
||||
{p4, true},
|
||||
{p5, true},
|
||||
{p6, false},
|
||||
{p7, false},
|
||||
{p6, true},
|
||||
{p7, true},
|
||||
}
|
||||
|
||||
c := newHnswChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "HNSW"
|
||||
var err error
|
||||
if CheckStrByValues(test.params, common.MetricTypeKey, BinaryVectorMetrics) {
|
||||
err = c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
|
||||
} else {
|
||||
err = c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
}
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -162,9 +169,9 @@ func Test_hnswChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newHnswChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
||||
err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType})
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -200,14 +207,14 @@ func Test_hnswChecker_SetDefaultMetricType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newHnswChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
|
||||
for _, test := range cases {
|
||||
p := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
}
|
||||
c.SetDefaultMetricTypeIfNotExist(p, test.dType)
|
||||
c.SetDefaultMetricTypeIfNotExist(test.dType, p)
|
||||
assert.Equal(t, p[Metric], test.metricType)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_HybridIndexChecker(t *testing.T) {
|
||||
c := newHYBRIDChecker()
|
||||
|
||||
assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, map[string]string{"bitmap_cardinality_limit": "100"}))
|
||||
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
|
||||
|
||||
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
|
||||
assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{}))
|
||||
assert.Error(t, c.CheckTrain(schemapb.DataType_Float, map[string]string{"bitmap_cardinality_limit": "0"}))
|
||||
assert.Error(t, c.CheckTrain(schemapb.DataType_Double, map[string]string{"bitmap_cardinality_limit": "2000"}))
|
||||
}
|
|
@ -12,15 +12,15 @@ type HYBRIDChecker struct {
|
|||
scalarIndexChecker
|
||||
}
|
||||
|
||||
func (c *HYBRIDChecker) CheckTrain(params map[string]string) error {
|
||||
func (c *HYBRIDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if !CheckIntByRange(params, common.BitmapCardinalityLimitKey, 1, MaxBitmapCardinalityLimit) {
|
||||
return fmt.Errorf("failed to check bitmap cardinality limit, should be larger than 0 and smaller than %d",
|
||||
MaxBitmapCardinalityLimit)
|
||||
}
|
||||
return c.scalarIndexChecker.CheckTrain(params)
|
||||
return c.scalarIndexChecker.CheckTrain(dataType, params)
|
||||
}
|
||||
|
||||
func (c *HYBRIDChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c *HYBRIDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
mainType := field.GetDataType()
|
||||
elemType := field.GetElementType()
|
||||
if !typeutil.IsBoolType(mainType) && !typeutil.IsIntegerType(mainType) &&
|
|
@ -21,8 +21,8 @@ import (
|
|||
)
|
||||
|
||||
type IndexChecker interface {
|
||||
CheckTrain(map[string]string) error
|
||||
CheckValidDataType(field *schemapb.FieldSchema) error
|
||||
SetDefaultMetricTypeIfNotExist(map[string]string, schemapb.DataType)
|
||||
StaticCheck(map[string]string) error
|
||||
CheckTrain(schemapb.DataType, map[string]string) error
|
||||
CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error
|
||||
SetDefaultMetricTypeIfNotExist(schemapb.DataType, map[string]string)
|
||||
StaticCheck(schemapb.DataType, map[string]string) error
|
||||
}
|
|
@ -15,6 +15,7 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
)
|
||||
|
||||
|
@ -23,31 +24,7 @@ type IndexType = string
|
|||
|
||||
// IndexType definitions
|
||||
const (
|
||||
// vector index
|
||||
IndexGpuBF IndexType = "GPU_BRUTE_FORCE"
|
||||
IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT"
|
||||
IndexRaftIvfPQ IndexType = "GPU_IVF_PQ"
|
||||
IndexRaftCagra IndexType = "GPU_CAGRA"
|
||||
IndexRaftBruteForce IndexType = "GPU_BRUTE_FORCE"
|
||||
IndexFaissIDMap IndexType = "FLAT" // no index is built.
|
||||
IndexFaissIvfFlat IndexType = "IVF_FLAT"
|
||||
IndexFaissIvfPQ IndexType = "IVF_PQ"
|
||||
IndexScaNN IndexType = "SCANN"
|
||||
IndexFaissIvfSQ8 IndexType = "IVF_SQ8"
|
||||
IndexFaissBinIDMap IndexType = "BIN_FLAT"
|
||||
IndexFaissBinIvfFlat IndexType = "BIN_IVF_FLAT"
|
||||
IndexHNSW IndexType = "HNSW"
|
||||
IndexDISKANN IndexType = "DISKANN"
|
||||
IndexSparseInverted IndexType = "SPARSE_INVERTED_INDEX"
|
||||
IndexSparseWand IndexType = "SPARSE_WAND"
|
||||
// For temporary use, will be removed in the future.
|
||||
// 1. All Index related param check will be moved to Knowhere recently.
|
||||
// 2. FAISS_HNSW_xxx will be rename to HNSW_xxx after QA test. We keep the original name for comparison purpose.
|
||||
// TODO: @liliu-z @foxspy
|
||||
IndexFaissHNSW IndexType = "FAISS_HNSW_FLAT"
|
||||
IndexFaissHNSWPQ IndexType = "FAISS_HNSW_PQ"
|
||||
IndexFaissHNSWSQ IndexType = "FAISS_HNSW_SQ"
|
||||
IndexFaissHNSWPRQ IndexType = "FAISS_HNSW_PRQ"
|
||||
IndexVector IndexType = "VECINDEX"
|
||||
|
||||
// scalar index
|
||||
IndexSTLSORT IndexType = "STL_SORT"
|
||||
|
@ -66,28 +43,12 @@ func IsScalarIndexType(indexType IndexType) bool {
|
|||
}
|
||||
|
||||
func IsGpuIndex(indexType IndexType) bool {
|
||||
return indexType == IndexGpuBF ||
|
||||
indexType == IndexRaftIvfFlat ||
|
||||
indexType == IndexRaftIvfPQ ||
|
||||
indexType == IndexRaftCagra
|
||||
return vecindexmgr.GetVecIndexMgrInstance().IsGPUVecIndex(indexType)
|
||||
}
|
||||
|
||||
// IsVectorMmapIndex check if the vector index can be mmaped
|
||||
func IsVectorMmapIndex(indexType IndexType) bool {
|
||||
return indexType == IndexFaissIDMap ||
|
||||
indexType == IndexFaissIvfFlat ||
|
||||
indexType == IndexFaissIvfPQ ||
|
||||
indexType == IndexFaissIvfSQ8 ||
|
||||
indexType == IndexFaissBinIDMap ||
|
||||
indexType == IndexFaissBinIvfFlat ||
|
||||
indexType == IndexHNSW ||
|
||||
indexType == IndexFaissHNSW ||
|
||||
indexType == IndexFaissHNSWPQ ||
|
||||
indexType == IndexFaissHNSWSQ ||
|
||||
indexType == IndexFaissHNSWPRQ ||
|
||||
indexType == IndexScaNN ||
|
||||
indexType == IndexSparseInverted ||
|
||||
indexType == IndexSparseWand
|
||||
return vecindexmgr.GetVecIndexMgrInstance().IsMMapSupported(indexType)
|
||||
}
|
||||
|
||||
func IsOffsetCacheSupported(indexType IndexType) bool {
|
||||
|
@ -95,7 +56,7 @@ func IsOffsetCacheSupported(indexType IndexType) bool {
|
|||
}
|
||||
|
||||
func IsDiskIndex(indexType IndexType) bool {
|
||||
return indexType == IndexDISKANN
|
||||
return vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType)
|
||||
}
|
||||
|
||||
func IsScalarMmapIndex(indexType IndexType) bool {
|
|
@ -12,11 +12,11 @@ type INVERTEDChecker struct {
|
|||
scalarIndexChecker
|
||||
}
|
||||
|
||||
func (c *INVERTEDChecker) CheckTrain(params map[string]string) error {
|
||||
return c.scalarIndexChecker.CheckTrain(params)
|
||||
func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
return c.scalarIndexChecker.CheckTrain(dataType, params)
|
||||
}
|
||||
|
||||
func (c *INVERTEDChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c *INVERTEDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
dType := field.GetDataType()
|
||||
if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) &&
|
||||
!typeutil.IsArrayType(dType) {
|
|
@ -0,0 +1,25 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_INVERTEDIndexChecker(t *testing.T) {
|
||||
c := newINVERTEDChecker()
|
||||
|
||||
assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, map[string]string{}))
|
||||
|
||||
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Array}))
|
||||
|
||||
assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}))
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package indexparamcheck
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
|
||||
type ivfBaseChecker struct {
|
||||
floatVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c ivfBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
|
||||
return errOutOfRange(NLIST, MinNList, MaxNList)
|
||||
}
|
||||
|
||||
// skip check number of rows
|
||||
|
||||
return c.floatVectorBaseChecker.staticCheck(params)
|
||||
}
|
||||
|
||||
func (c ivfBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.StaticCheck(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.floatVectorBaseChecker.CheckTrain(dataType, params)
|
||||
}
|
||||
|
||||
func newIVFBaseChecker() IndexChecker {
|
||||
return &ivfBaseChecker{}
|
||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -70,9 +71,10 @@ func Test_ivfBaseChecker_CheckTrain(t *testing.T) {
|
|||
{p7, false},
|
||||
}
|
||||
|
||||
c := newIVFBaseChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT")
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "IVF_FLAT"
|
||||
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -140,9 +142,9 @@ func Test_ivfBaseChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newIVFBaseChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT")
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
||||
err := c.CheckValidDataType("IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType})
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -3,6 +3,8 @@ package indexparamcheck
|
|||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
// ivfPQChecker checks if a IVF_PQ index can be built.
|
||||
|
@ -11,8 +13,8 @@ type ivfPQChecker struct {
|
|||
}
|
||||
|
||||
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
|
||||
func (c *ivfPQChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
|
||||
func (c *ivfPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -141,9 +142,11 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
|
|||
{p7, false},
|
||||
}
|
||||
|
||||
// c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ")
|
||||
c := newIVFPQChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "IVF_PQ"
|
||||
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -211,9 +214,9 @@ func Test_ivfPQChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newIVFPQChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ")
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
||||
err := c.CheckValidDataType("IVF_PQ", &schemapb.FieldSchema{DataType: test.dType})
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -2,6 +2,8 @@ package indexparamcheck
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
// ivfSQChecker checks if a IVF_SQ index can be built.
|
||||
|
@ -22,11 +24,11 @@ func (c *ivfSQChecker) checkNBits(params map[string]string) error {
|
|||
}
|
||||
|
||||
// CheckTrain returns true if the index can be built with the specific index parameters.
|
||||
func (c *ivfSQChecker) CheckTrain(params map[string]string) error {
|
||||
func (c *ivfSQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.checkNBits(params); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.ivfBaseChecker.CheckTrain(params)
|
||||
return c.ivfBaseChecker.CheckTrain(dataType, params)
|
||||
}
|
||||
|
||||
func newIVFSQChecker() IndexChecker {
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -78,7 +79,6 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) {
|
|||
}{
|
||||
{validParams, true},
|
||||
{validParamsWithNBits, true},
|
||||
{paramsWithInvalidNBits, false},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
{p1, true},
|
||||
|
@ -90,9 +90,10 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) {
|
|||
{p7, false},
|
||||
}
|
||||
|
||||
c := newIVFSQChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ")
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "IVF_SQ"
|
||||
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -160,9 +161,9 @@ func Test_ivfSQChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newIVFSQChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ")
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
||||
err := c.CheckValidDataType("IVF_SQ8", &schemapb.FieldSchema{DataType: test.dType})
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -1,14 +1,18 @@
|
|||
package indexparamcheck
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
type raftBruteForceChecker struct {
|
||||
floatVectorBaseChecker
|
||||
}
|
||||
|
||||
// raftBrustForceChecker checks if a Brute_Force index can be built.
|
||||
func (c raftBruteForceChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.floatVectorBaseChecker.CheckTrain(params); err != nil {
|
||||
func (c raftBruteForceChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.floatVectorBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
if !CheckStrByValues(params, Metric, RaftMetrics) {
|
|
@ -6,6 +6,9 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -52,9 +55,14 @@ func Test_raftbfChecker_CheckTrain(t *testing.T) {
|
|||
{p7, false},
|
||||
}
|
||||
|
||||
c := newRaftBruteForceChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_BRUTE_FORCE")
|
||||
if c == nil {
|
||||
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||
return
|
||||
}
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "GPU_BRUTE_FORCE"
|
||||
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -1,6 +1,10 @@
|
|||
package indexparamcheck
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
// raftIVFChecker checks if a RAFT_IVF_Flat index can be built.
|
||||
type raftIVFFlatChecker struct {
|
||||
|
@ -8,8 +12,8 @@ type raftIVFFlatChecker struct {
|
|||
}
|
||||
|
||||
// CheckTrain checks if ivf-flat index can be built with the specific index parameters.
|
||||
func (c *raftIVFFlatChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
|
||||
func (c *raftIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
if !CheckStrByValues(params, Metric, RaftMetrics) {
|
|
@ -7,6 +7,8 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -84,9 +86,14 @@ func Test_raftIvfFlatChecker_CheckTrain(t *testing.T) {
|
|||
{p9, false},
|
||||
}
|
||||
|
||||
c := newRaftIVFFlatChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_FLAT")
|
||||
if c == nil {
|
||||
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||
return
|
||||
}
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "GPU_IVF_FLAT"
|
||||
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -154,9 +161,13 @@ func Test_raftIvfFlatChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newRaftIVFFlatChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_FLAT")
|
||||
if c == nil {
|
||||
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||
return
|
||||
}
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
||||
err := c.CheckValidDataType("GPU_IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType})
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -3,6 +3,8 @@ package indexparamcheck
|
|||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
// raftIVFPQChecker checks if a RAFT_IVF_PQ index can be built.
|
||||
|
@ -11,8 +13,8 @@ type raftIVFPQChecker struct {
|
|||
}
|
||||
|
||||
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
|
||||
func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
|
||||
func (c *raftIVFPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
if !CheckStrByValues(params, Metric, RaftMetrics) {
|
|
@ -7,6 +7,8 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -144,9 +146,14 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
|
|||
{p9, false},
|
||||
}
|
||||
|
||||
c := newRaftIVFPQChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_PQ")
|
||||
if c == nil {
|
||||
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||
return
|
||||
}
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "GPU_IVF_PQ"
|
||||
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -214,9 +221,13 @@ func Test_raftIVFPQChecker_CheckValidDataType(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
c := newRaftIVFPQChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_PQ")
|
||||
if c == nil {
|
||||
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||
return
|
||||
}
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
||||
err := c.CheckValidDataType("GPU_IVF_PQ", &schemapb.FieldSchema{DataType: test.dType})
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -0,0 +1,11 @@
|
|||
package indexparamcheck
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
|
||||
type scalarIndexChecker struct {
|
||||
baseChecker
|
||||
}
|
||||
|
||||
func (c scalarIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
return nil
|
||||
}
|
|
@ -4,9 +4,11 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func TestCheckIndexValid(t *testing.T) {
|
||||
scalarIndexChecker := &scalarIndexChecker{}
|
||||
assert.NoError(t, scalarIndexChecker.CheckTrain(map[string]string{}))
|
||||
assert.NoError(t, scalarIndexChecker.CheckTrain(schemapb.DataType_Bool, map[string]string{}))
|
||||
}
|
|
@ -3,6 +3,8 @@ package indexparamcheck
|
|||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
// scaNNChecker checks if a SCANN index can be built.
|
||||
|
@ -11,8 +13,8 @@ type scaNNChecker struct {
|
|||
}
|
||||
|
||||
// CheckTrain checks if SCANN index can be built with the specific index parameters.
|
||||
func (c *scaNNChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
|
||||
func (c *scaNNChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
|
@ -87,9 +88,10 @@ func Test_scaNNChecker_CheckTrain(t *testing.T) {
|
|||
{p7, false},
|
||||
}
|
||||
|
||||
c := newScaNNChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("SCANN")
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
test.params[common.IndexTypeKey] = "SCANN"
|
||||
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
|
@ -159,7 +161,7 @@ func Test_scaNNChecker_CheckValidDataType(t *testing.T) {
|
|||
|
||||
c := newScaNNChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
||||
err := c.CheckValidDataType("SCANN", &schemapb.FieldSchema{DataType: test.dType})
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
|
@ -12,7 +12,7 @@ import (
|
|||
// sparse vector don't check for dim, but baseChecker does, thus not including baseChecker
|
||||
type sparseFloatVectorBaseChecker struct{}
|
||||
|
||||
func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) error {
|
||||
func (c sparseFloatVectorBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||
if !CheckStrByValues(params, Metric, SparseMetrics) {
|
||||
return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics)
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) erro
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error {
|
||||
func (c sparseFloatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
dropRatioBuildStr, exist := params[SparseDropRatioBuild]
|
||||
if exist {
|
||||
dropRatioBuild, err := strconv.ParseFloat(dropRatioBuildStr, 64)
|
||||
|
@ -48,14 +48,14 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c sparseFloatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c sparseFloatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
if !typeutil.IsSparseFloatVectorType(field.GetDataType()) {
|
||||
return fmt.Errorf("only sparse float vector is supported for the specified index tpye")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
|
||||
func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||
setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType)
|
||||
}
|
||||
|
|
@ -6,84 +6,95 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
)
|
||||
|
||||
func Test_sparseFloatVectorBaseChecker_StaticCheck(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
Metric: "IP",
|
||||
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||
Metric: "IP",
|
||||
}
|
||||
|
||||
invalidParams := map[string]string{
|
||||
Metric: "L2",
|
||||
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||
Metric: "L2",
|
||||
}
|
||||
|
||||
c := newSparseFloatVectorBaseChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX")
|
||||
|
||||
t.Run("valid metric", func(t *testing.T) {
|
||||
err := c.StaticCheck(validParams)
|
||||
err := c.StaticCheck(schemapb.DataType_SparseFloatVector, validParams)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid metric", func(t *testing.T) {
|
||||
err := c.StaticCheck(invalidParams)
|
||||
err := c.StaticCheck(schemapb.DataType_SparseFloatVector, invalidParams)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_sparseFloatVectorBaseChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||
Metric: "IP",
|
||||
SparseDropRatioBuild: "0.5",
|
||||
BM25K1: "1.5",
|
||||
BM25B: "0.5",
|
||||
}
|
||||
|
||||
invalidDropRatio := map[string]string{
|
||||
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||
Metric: "IP",
|
||||
SparseDropRatioBuild: "1.5",
|
||||
}
|
||||
|
||||
invalidBM25K1 := map[string]string{
|
||||
BM25K1: "3.5",
|
||||
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||
Metric: "IP",
|
||||
BM25K1: "3.5",
|
||||
}
|
||||
|
||||
invalidBM25B := map[string]string{
|
||||
BM25B: "1.5",
|
||||
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||
Metric: "IP",
|
||||
BM25B: "1.5",
|
||||
}
|
||||
|
||||
c := newSparseFloatVectorBaseChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX")
|
||||
|
||||
t.Run("valid params", func(t *testing.T) {
|
||||
err := c.CheckTrain(validParams)
|
||||
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, validParams)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid drop ratio", func(t *testing.T) {
|
||||
err := c.CheckTrain(invalidDropRatio)
|
||||
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidDropRatio)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid BM25K1", func(t *testing.T) {
|
||||
err := c.CheckTrain(invalidBM25K1)
|
||||
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidBM25K1)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid BM25B", func(t *testing.T) {
|
||||
err := c.CheckTrain(invalidBM25B)
|
||||
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidBM25B)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_sparseFloatVectorBaseChecker_CheckValidDataType(t *testing.T) {
|
||||
c := newSparseFloatVectorBaseChecker()
|
||||
c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX")
|
||||
|
||||
t.Run("valid data type", func(t *testing.T) {
|
||||
field := &schemapb.FieldSchema{DataType: schemapb.DataType_SparseFloatVector}
|
||||
err := c.CheckValidDataType(field)
|
||||
err := c.CheckValidDataType("SPARSE_WAND", field)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid data type", func(t *testing.T) {
|
||||
field := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
|
||||
err := c.CheckValidDataType(field)
|
||||
err := c.CheckValidDataType("SPARSE_WAND", field)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
|
@ -12,11 +12,11 @@ type STLSORTChecker struct {
|
|||
scalarIndexChecker
|
||||
}
|
||||
|
||||
func (c *STLSORTChecker) CheckTrain(params map[string]string) error {
|
||||
return c.scalarIndexChecker.CheckTrain(params)
|
||||
func (c *STLSORTChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
return c.scalarIndexChecker.CheckTrain(dataType, params)
|
||||
}
|
||||
|
||||
func (c *STLSORTChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c *STLSORTChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
if !typeutil.IsArithmetic(field.GetDataType()) {
|
||||
return fmt.Errorf("STL_SORT are only supported on numeric field")
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_STLSORTIndexChecker(t *testing.T) {
|
||||
c := newSTLSORTChecker()
|
||||
|
||||
assert.NoError(t, c.CheckTrain(schemapb.DataType_Int64, map[string]string{}))
|
||||
|
||||
assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||
|
||||
assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||
}
|
|
@ -12,11 +12,11 @@ type TRIEChecker struct {
|
|||
scalarIndexChecker
|
||||
}
|
||||
|
||||
func (c *TRIEChecker) CheckTrain(params map[string]string) error {
|
||||
return c.scalarIndexChecker.CheckTrain(params)
|
||||
func (c *TRIEChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
return c.scalarIndexChecker.CheckTrain(dataType, params)
|
||||
}
|
||||
|
||||
func (c *TRIEChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
||||
func (c *TRIEChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
if !typeutil.IsStringType(field.GetDataType()) {
|
||||
return fmt.Errorf("TRIE are only supported on varchar field")
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_TrieIndexChecker(t *testing.T) {
|
||||
c := newTRIEChecker()
|
||||
|
||||
assert.NoError(t, c.CheckTrain(schemapb.DataType_VarChar, map[string]string{}))
|
||||
|
||||
assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
||||
assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||
|
||||
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||
}
|
|
@ -20,7 +20,10 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
// CheckIntByRange check if the data corresponding to the key is in the range of [min, max].
|
||||
|
@ -69,3 +72,30 @@ func setDefaultIfNotExist(params map[string]string, key string, defaultValue str
|
|||
params[key] = defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
func CheckAutoIndexHelper(key string, m map[string]string, dtype schemapb.DataType) {
|
||||
indexType, ok := m[common.IndexTypeKey]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("%s invalid, index type not found", key))
|
||||
}
|
||||
|
||||
checker, err := GetIndexCheckerMgrInstance().GetChecker(indexType)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType))
|
||||
}
|
||||
|
||||
if err := checker.StaticCheck(dtype, m); err != nil {
|
||||
panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func CheckAutoIndexConfig() {
|
||||
autoIndexCfg := ¶mtable.Get().AutoIndexConfig
|
||||
CheckAutoIndexHelper(autoIndexCfg.IndexParams.Key, autoIndexCfg.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||
CheckAutoIndexHelper(autoIndexCfg.BinaryIndexParams.Key, autoIndexCfg.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
|
||||
CheckAutoIndexHelper(autoIndexCfg.SparseIndexParams.Key, autoIndexCfg.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector)
|
||||
}
|
||||
|
||||
func ValidateParamTable() {
|
||||
CheckAutoIndexConfig()
|
||||
}
|
|
@ -0,0 +1,269 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/config"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
func Test_CheckIntByRange(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"1": strconv.Itoa(1),
|
||||
"2": strconv.Itoa(2),
|
||||
"3": strconv.Itoa(3),
|
||||
"s1": "s1",
|
||||
"s2": "s2",
|
||||
"s3": "s3",
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
key string
|
||||
min int
|
||||
max int
|
||||
want bool
|
||||
}{
|
||||
{params, "1", 0, 4, true},
|
||||
{params, "2", 0, 4, true},
|
||||
{params, "3", 0, 4, true},
|
||||
{params, "1", 4, 5, false},
|
||||
{params, "2", 4, 5, false},
|
||||
{params, "3", 4, 5, false},
|
||||
{params, "4", 0, 4, false},
|
||||
{params, "5", 0, 4, false},
|
||||
{params, "6", 0, 4, false},
|
||||
{params, "s1", 0, 4, false},
|
||||
{params, "s2", 0, 4, false},
|
||||
{params, "s3", 0, 4, false},
|
||||
{params, "s4", 0, 4, false},
|
||||
{params, "s5", 0, 4, false},
|
||||
{params, "s6", 0, 4, false},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
if got := CheckIntByRange(test.params, test.key, test.min, test.max); got != test.want {
|
||||
t.Errorf("CheckIntByRange(%v, %v, %v, %v) = %v", test.params, test.key, test.min, test.max, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CheckStrByValues(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"1": strconv.Itoa(1),
|
||||
"2": strconv.Itoa(2),
|
||||
"3": strconv.Itoa(3),
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
key string
|
||||
container []string
|
||||
want bool
|
||||
}{
|
||||
{params, "1", []string{"1", "2", "3"}, true},
|
||||
{params, "2", []string{"1", "2", "3"}, true},
|
||||
{params, "3", []string{"1", "2", "3"}, true},
|
||||
{params, "1", []string{"4", "5", "6"}, false},
|
||||
{params, "2", []string{"4", "5", "6"}, false},
|
||||
{params, "3", []string{"4", "5", "6"}, false},
|
||||
{params, "1", []string{}, false},
|
||||
{params, "2", []string{}, false},
|
||||
{params, "3", []string{}, false},
|
||||
{params, "4", []string{"1", "2", "3"}, false},
|
||||
{params, "5", []string{"1", "2", "3"}, false},
|
||||
{params, "6", []string{"1", "2", "3"}, false},
|
||||
{params, "4", []string{"4", "5", "6"}, false},
|
||||
{params, "5", []string{"4", "5", "6"}, false},
|
||||
{params, "6", []string{"4", "5", "6"}, false},
|
||||
{params, "4", []string{}, false},
|
||||
{params, "5", []string{}, false},
|
||||
{params, "6", []string{}, false},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
if got := CheckStrByValues(test.params, test.key, test.container); got != test.want {
|
||||
t.Errorf("CheckStrByValues(%v, %v, %v) = %v", test.params, test.key, test.container, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CheckAutoIndex(t *testing.T) {
|
||||
t.Run("index type not found", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"M": 30}`)
|
||||
p := ¶mtable.AutoIndexConfig{
|
||||
IndexParams: paramtable.ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.Panics(t, func() {
|
||||
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("unsupported index type", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"index_type": "not supported"}`)
|
||||
p := ¶mtable.AutoIndexConfig{
|
||||
IndexParams: paramtable.ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.Panics(t, func() {
|
||||
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("normal case, hnsw", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"M": 30,"efConstruction": 360,"index_type": "HNSW"}`)
|
||||
p := ¶mtable.AutoIndexConfig{
|
||||
IndexParams: paramtable.ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, metric.COSINE, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, binary vector", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.binary.build", `{"nlist": 1024, "index_type": "BIN_IVF_FLAT"}`)
|
||||
p := ¶mtable.AutoIndexConfig{
|
||||
BinaryIndexParams: paramtable.ParamItem{
|
||||
Key: "autoIndex.params.binary.build",
|
||||
},
|
||||
}
|
||||
p.BinaryIndexParams.Init(mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
CheckAutoIndexHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
|
||||
})
|
||||
metricType, exist := p.BinaryIndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, metric.HAMMING, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, sparse vector", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.sparse.build", `{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}`)
|
||||
p := ¶mtable.AutoIndexConfig{
|
||||
SparseIndexParams: paramtable.ParamItem{
|
||||
Key: "autoIndex.params.sparse.build",
|
||||
},
|
||||
}
|
||||
p.SparseIndexParams.Init(mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
CheckAutoIndexHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector)
|
||||
})
|
||||
metricType, exist := p.SparseIndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, metric.IP, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, ivf flat", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
|
||||
p := ¶mtable.AutoIndexConfig{
|
||||
IndexParams: paramtable.ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, metric.COSINE, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, ivf flat", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
|
||||
p := ¶mtable.AutoIndexConfig{
|
||||
IndexParams: paramtable.ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, metric.COSINE, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, diskann", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"index_type": "DISKANN"}`)
|
||||
p := ¶mtable.AutoIndexConfig{
|
||||
IndexParams: paramtable.ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, metric.COSINE, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, bin flat", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"index_type": "BIN_FLAT"}`)
|
||||
p := ¶mtable.AutoIndexConfig{
|
||||
IndexParams: paramtable.ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, metric.HAMMING, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, bin ivf flat", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "BIN_IVF_FLAT"}`)
|
||||
p := ¶mtable.AutoIndexConfig{
|
||||
IndexParams: paramtable.ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, metric.HAMMING, metricType)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package indexparamcheck
|
||||
|
||||
/*
|
||||
#cgo pkg-config: milvus_core
|
||||
|
||||
#include <stdlib.h> // free
|
||||
#include "segcore/vector_index_c.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/indexcgopb"
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type vecIndexChecker struct {
|
||||
baseChecker
|
||||
}
|
||||
|
||||
// HandleCStatus deals with the error returned from CGO
|
||||
func HandleCStatus(status *C.CStatus) error {
|
||||
if status.error_code == 0 {
|
||||
return nil
|
||||
}
|
||||
errorMsg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
|
||||
return fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||
if typeutil.IsDenseFloatVectorType(dataType) {
|
||||
if !CheckStrByValues(params, Metric, FloatVectorMetrics) {
|
||||
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], FloatVectorMetrics)
|
||||
}
|
||||
} else if typeutil.IsSparseFloatVectorType(dataType) {
|
||||
if !CheckStrByValues(params, Metric, SparseMetrics) {
|
||||
return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics)
|
||||
}
|
||||
} else if typeutil.IsBinaryVectorType(dataType) {
|
||||
if !CheckStrByValues(params, Metric, BinaryVectorMetrics) {
|
||||
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinaryVectorMetrics)
|
||||
}
|
||||
}
|
||||
|
||||
indexType, exist := params[common.IndexTypeKey]
|
||||
|
||||
if !exist {
|
||||
return fmt.Errorf("no indexType is specified")
|
||||
}
|
||||
|
||||
if !vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) {
|
||||
return fmt.Errorf("indexType %s is not supported", indexType)
|
||||
}
|
||||
|
||||
protoIndexParams := &indexcgopb.IndexParams{
|
||||
Params: make([]*commonpb.KeyValuePair, 0),
|
||||
}
|
||||
|
||||
for key, value := range params {
|
||||
protoIndexParams.Params = append(protoIndexParams.Params, &commonpb.KeyValuePair{Key: key, Value: value})
|
||||
}
|
||||
|
||||
indexParamsBlob, err := proto.Marshal(protoIndexParams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal index params: %s", err)
|
||||
}
|
||||
|
||||
var status C.CStatus
|
||||
|
||||
cIndexType := C.CString(indexType)
|
||||
cDataType := uint32(dataType)
|
||||
status = C.ValidateIndexParams(cIndexType, cDataType, (*C.uint8_t)(unsafe.Pointer(&indexParamsBlob[0])), (C.uint64_t)(len(indexParamsBlob)))
|
||||
C.free(unsafe.Pointer(cIndexType))
|
||||
|
||||
return HandleCStatus(&status)
|
||||
}
|
||||
|
||||
func (c vecIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||
if err := c.StaticCheck(dataType, params); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.baseChecker.CheckTrain(dataType, params)
|
||||
}
|
||||
|
||||
func (c vecIndexChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||
if !typeutil.IsVectorType(field.GetDataType()) {
|
||||
return fmt.Errorf("index %s only supports vector data type", indexType)
|
||||
}
|
||||
if !vecindexmgr.GetVecIndexMgrInstance().IsDataTypeSupport(indexType, field.GetDataType()) {
|
||||
return fmt.Errorf("index %s do not support data type: %s", indexType, schemapb.DataType_name[int32(field.GetDataType())])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c vecIndexChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||
paramtable.SetDefaultMetricTypeIfNotExist(dType, params)
|
||||
}
|
||||
|
||||
func newVecIndexChecker() IndexChecker {
|
||||
return &vecIndexChecker{}
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func TestVecIndexChecker_StaticCheck(t *testing.T) {
|
||||
checker := newVecIndexChecker()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dataType schemapb.DataType
|
||||
params map[string]string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid IVF_FLAT index",
|
||||
dataType: schemapb.DataType_FloatVector,
|
||||
params: map[string]string{
|
||||
"index_type": "IVF_FLAT",
|
||||
"metric_type": "L2",
|
||||
"nlist": "1024",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid index type",
|
||||
dataType: schemapb.DataType_FloatVector,
|
||||
params: map[string]string{
|
||||
"index_type": "INVALID_INDEX",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Missing index type",
|
||||
dataType: schemapb.DataType_FloatVector,
|
||||
params: map[string]string{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checker.StaticCheck(tt.dataType, tt.params)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVecIndexChecker_CheckValidDataType(t *testing.T) {
|
||||
checker := newVecIndexChecker()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
indexType IndexType
|
||||
field *schemapb.FieldSchema
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid float vector",
|
||||
indexType: "IVF_FLAT",
|
||||
field: &schemapb.FieldSchema{
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid data type",
|
||||
indexType: "IVF_FLAT",
|
||||
field: &schemapb.FieldSchema{
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checker.CheckValidDataType(tt.indexType, tt.field)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVecIndexChecker_SetDefaultMetricTypeIfNotExist(t *testing.T) {
|
||||
checker := newVecIndexChecker()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dataType schemapb.DataType
|
||||
params map[string]string
|
||||
expectedType string
|
||||
}{
|
||||
{
|
||||
name: "Float vector",
|
||||
dataType: schemapb.DataType_FloatVector,
|
||||
params: map[string]string{},
|
||||
expectedType: FloatVectorDefaultMetricType,
|
||||
},
|
||||
{
|
||||
name: "Binary vector",
|
||||
dataType: schemapb.DataType_BinaryVector,
|
||||
params: map[string]string{},
|
||||
expectedType: BinaryVectorDefaultMetricType,
|
||||
},
|
||||
{
|
||||
name: "Existing metric type",
|
||||
dataType: schemapb.DataType_FloatVector,
|
||||
params: map[string]string{"metric_type": "IP"},
|
||||
expectedType: "IP",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
checker.SetDefaultMetricTypeIfNotExist(tt.dataType, tt.params)
|
||||
assert.Equal(t, tt.expectedType, tt.params["metric_type"])
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,17 +0,0 @@
|
|||
package indexparamcheck
|
||||
|
||||
type binFlatChecker struct {
|
||||
binaryVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c binFlatChecker) CheckTrain(params map[string]string) error {
|
||||
return c.binaryVectorBaseChecker.CheckTrain(params)
|
||||
}
|
||||
|
||||
func (c binFlatChecker) StaticCheck(params map[string]string) error {
|
||||
return c.staticCheck(params)
|
||||
}
|
||||
|
||||
func newBinFlatChecker() IndexChecker {
|
||||
return &binFlatChecker{}
|
||||
}
|
|
@ -1,33 +0,0 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_BitmapIndexChecker(t *testing.T) {
|
||||
c := newBITMAPChecker()
|
||||
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
|
||||
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double, IsPrimaryKey: true}))
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_HybridIndexChecker(t *testing.T) {
|
||||
c := newHYBRIDChecker()
|
||||
|
||||
assert.NoError(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "100"}))
|
||||
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
|
||||
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
|
||||
assert.Error(t, c.CheckTrain(map[string]string{}))
|
||||
assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "0"}))
|
||||
assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "2000"}))
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_INVERTEDIndexChecker(t *testing.T) {
|
||||
c := newINVERTEDChecker()
|
||||
|
||||
assert.NoError(t, c.CheckTrain(map[string]string{}))
|
||||
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array}))
|
||||
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}))
|
||||
}
|
|
@ -1,26 +0,0 @@
|
|||
package indexparamcheck
|
||||
|
||||
type ivfBaseChecker struct {
|
||||
floatVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c ivfBaseChecker) StaticCheck(params map[string]string) error {
|
||||
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
|
||||
return errOutOfRange(NLIST, MinNList, MaxNList)
|
||||
}
|
||||
|
||||
// skip check number of rows
|
||||
|
||||
return c.floatVectorBaseChecker.staticCheck(params)
|
||||
}
|
||||
|
||||
func (c ivfBaseChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.StaticCheck(params); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.floatVectorBaseChecker.CheckTrain(params)
|
||||
}
|
||||
|
||||
func newIVFBaseChecker() IndexChecker {
|
||||
return &ivfBaseChecker{}
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
package indexparamcheck
|
||||
|
||||
type scalarIndexChecker struct {
|
||||
baseChecker
|
||||
}
|
||||
|
||||
func (c scalarIndexChecker) CheckTrain(params map[string]string) error {
|
||||
return nil
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_STLSORTIndexChecker(t *testing.T) {
|
||||
c := newSTLSORTChecker()
|
||||
|
||||
assert.NoError(t, c.CheckTrain(map[string]string{}))
|
||||
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_TrieIndexChecker(t *testing.T) {
|
||||
c := newTRIEChecker()
|
||||
|
||||
assert.NoError(t, c.CheckTrain(map[string]string{}))
|
||||
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||
}
|
|
@ -1,87 +0,0 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_CheckIntByRange(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"1": strconv.Itoa(1),
|
||||
"2": strconv.Itoa(2),
|
||||
"3": strconv.Itoa(3),
|
||||
"s1": "s1",
|
||||
"s2": "s2",
|
||||
"s3": "s3",
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
key string
|
||||
min int
|
||||
max int
|
||||
want bool
|
||||
}{
|
||||
{params, "1", 0, 4, true},
|
||||
{params, "2", 0, 4, true},
|
||||
{params, "3", 0, 4, true},
|
||||
{params, "1", 4, 5, false},
|
||||
{params, "2", 4, 5, false},
|
||||
{params, "3", 4, 5, false},
|
||||
{params, "4", 0, 4, false},
|
||||
{params, "5", 0, 4, false},
|
||||
{params, "6", 0, 4, false},
|
||||
{params, "s1", 0, 4, false},
|
||||
{params, "s2", 0, 4, false},
|
||||
{params, "s3", 0, 4, false},
|
||||
{params, "s4", 0, 4, false},
|
||||
{params, "s5", 0, 4, false},
|
||||
{params, "s6", 0, 4, false},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
if got := CheckIntByRange(test.params, test.key, test.min, test.max); got != test.want {
|
||||
t.Errorf("CheckIntByRange(%v, %v, %v, %v) = %v", test.params, test.key, test.min, test.max, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CheckStrByValues(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"1": strconv.Itoa(1),
|
||||
"2": strconv.Itoa(2),
|
||||
"3": strconv.Itoa(3),
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
key string
|
||||
container []string
|
||||
want bool
|
||||
}{
|
||||
{params, "1", []string{"1", "2", "3"}, true},
|
||||
{params, "2", []string{"1", "2", "3"}, true},
|
||||
{params, "3", []string{"1", "2", "3"}, true},
|
||||
{params, "1", []string{"4", "5", "6"}, false},
|
||||
{params, "2", []string{"4", "5", "6"}, false},
|
||||
{params, "3", []string{"4", "5", "6"}, false},
|
||||
{params, "1", []string{}, false},
|
||||
{params, "2", []string{}, false},
|
||||
{params, "3", []string{}, false},
|
||||
{params, "4", []string{"1", "2", "3"}, false},
|
||||
{params, "5", []string{"1", "2", "3"}, false},
|
||||
{params, "6", []string{"1", "2", "3"}, false},
|
||||
{params, "4", []string{"4", "5", "6"}, false},
|
||||
{params, "5", []string{"4", "5", "6"}, false},
|
||||
{params, "6", []string{"4", "5", "6"}, false},
|
||||
{params, "4", []string{}, false},
|
||||
{params, "5", []string{}, false},
|
||||
{params, "6", []string{}, false},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
if got := CheckStrByValues(test.params, test.key, test.container); got != test.want {
|
||||
t.Errorf("CheckStrByValues(%v, %v, %v) = %v", test.params, test.key, test.container, test.want)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -24,12 +24,13 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/config"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
// /////////////////////////////////////////////////////////////////////////////
|
||||
// --- common ---
|
||||
type autoIndexConfig struct {
|
||||
type AutoIndexConfig struct {
|
||||
Enable ParamItem `refreshable:"true"`
|
||||
EnableOptimize ParamItem `refreshable:"true"`
|
||||
EnableResultLimitCheck ParamItem `refreshable:"true"`
|
||||
|
@ -60,7 +61,7 @@ const (
|
|||
DefaultBitmapCardinalityLimit = 100
|
||||
)
|
||||
|
||||
func (p *autoIndexConfig) init(base *BaseTable) {
|
||||
func (p *AutoIndexConfig) init(base *BaseTable) {
|
||||
p.Enable = ParamItem{
|
||||
Key: "autoIndex.enable",
|
||||
Version: "2.2.0",
|
||||
|
@ -157,7 +158,7 @@ func (p *autoIndexConfig) init(base *BaseTable) {
|
|||
}
|
||||
p.AutoIndexTuningConfig.Init(base.mgr)
|
||||
|
||||
p.panicIfNotValidAndSetDefaultMetricType(base.mgr)
|
||||
p.SetDefaultMetricType(base.mgr)
|
||||
|
||||
p.ScalarAutoIndexEnable = ParamItem{
|
||||
Key: "scalarAutoIndex.enable",
|
||||
|
@ -244,37 +245,47 @@ func (p *autoIndexConfig) init(base *BaseTable) {
|
|||
p.ScalarBoolIndexType.Init(base.mgr)
|
||||
}
|
||||
|
||||
func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricType(mgr *config.Manager) {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
|
||||
// SetDefaultMetricType The config check logic has been moved to internal package; only set defulat metric here
|
||||
func (p *AutoIndexConfig) SetDefaultMetricType(mgr *config.Manager) {
|
||||
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
|
||||
}
|
||||
|
||||
func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricTypeHelper(key string, m map[string]string, dtype schemapb.DataType, mgr *config.Manager) {
|
||||
func setDefaultIfNotExist(params map[string]string, key string, defaultValue string) {
|
||||
_, exist := params[key]
|
||||
if !exist {
|
||||
params[key] = defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
FloatVectorDefaultMetricType = metric.COSINE
|
||||
SparseFloatVectorDefaultMetricType = metric.IP
|
||||
BinaryVectorDefaultMetricType = metric.HAMMING
|
||||
)
|
||||
|
||||
func SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||
if typeutil.IsDenseFloatVectorType(dType) {
|
||||
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
|
||||
} else if typeutil.IsSparseFloatVectorType(dType) {
|
||||
setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType)
|
||||
} else if typeutil.IsBinaryVectorType(dType) {
|
||||
setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AutoIndexConfig) SetDefaultMetricTypeHelper(key string, m map[string]string, dtype schemapb.DataType, mgr *config.Manager) {
|
||||
if m == nil {
|
||||
panic(fmt.Sprintf("%s invalid, should be json format", key))
|
||||
}
|
||||
|
||||
indexType, ok := m[common.IndexTypeKey]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("%s invalid, index type not found", key))
|
||||
}
|
||||
|
||||
checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType))
|
||||
}
|
||||
|
||||
checker.SetDefaultMetricTypeIfNotExist(m, dtype)
|
||||
|
||||
if err := checker.StaticCheck(m); err != nil {
|
||||
panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error()))
|
||||
}
|
||||
SetDefaultMetricTypeIfNotExist(dtype, m)
|
||||
|
||||
p.reset(key, m, mgr)
|
||||
}
|
||||
|
||||
func (p *autoIndexConfig) reset(key string, m map[string]string, mgr *config.Manager) {
|
||||
func (p *AutoIndexConfig) reset(key string, m map[string]string, mgr *config.Manager) {
|
||||
ret, err := funcutil.MapToJSON(m)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("%s: convert to json failed, parameters invalid, error: %s", key, err.Error()))
|
||||
|
|
|
@ -26,7 +26,6 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/config"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -134,180 +133,16 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) {
|
|||
t.Run("not in json format", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", "not in json format")
|
||||
p := &autoIndexConfig{
|
||||
p := &AutoIndexConfig{
|
||||
IndexParams: ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.Panics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("index type not found", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"M": 30}`)
|
||||
p := &autoIndexConfig{
|
||||
IndexParams: ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.Panics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("unsupported index type", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"index_type": "not supported"}`)
|
||||
p := &autoIndexConfig{
|
||||
IndexParams: ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.Panics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("normal case, hnsw", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"M": 30,"efConstruction": 360,"index_type": "HNSW"}`)
|
||||
p := &autoIndexConfig{
|
||||
IndexParams: ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, binary vector", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.binary.build", `{"nlist": 1024, "index_type": "BIN_IVF_FLAT"}`)
|
||||
p := &autoIndexConfig{
|
||||
BinaryIndexParams: ParamItem{
|
||||
Key: "autoIndex.params.binary.build",
|
||||
},
|
||||
}
|
||||
p.BinaryIndexParams.Init(mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
||||
})
|
||||
metricType, exist := p.BinaryIndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, sparse vector", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.sparse.build", `{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}`)
|
||||
p := &autoIndexConfig{
|
||||
SparseIndexParams: ParamItem{
|
||||
Key: "autoIndex.params.sparse.build",
|
||||
},
|
||||
}
|
||||
p.SparseIndexParams.Init(mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
|
||||
})
|
||||
metricType, exist := p.SparseIndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, indexparamcheck.SparseFloatVectorDefaultMetricType, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, ivf flat", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
|
||||
p := &autoIndexConfig{
|
||||
IndexParams: ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, ivf flat", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
|
||||
p := &autoIndexConfig{
|
||||
IndexParams: ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, diskann", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"index_type": "DISKANN"}`)
|
||||
p := &autoIndexConfig{
|
||||
IndexParams: ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, bin flat", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"index_type": "BIN_FLAT"}`)
|
||||
p := &autoIndexConfig{
|
||||
IndexParams: ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType)
|
||||
})
|
||||
|
||||
t.Run("normal case, bin ivf flat", func(t *testing.T) {
|
||||
mgr := config.NewManager()
|
||||
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "BIN_IVF_FLAT"}`)
|
||||
p := &autoIndexConfig{
|
||||
IndexParams: ParamItem{
|
||||
Key: "autoIndex.params.build",
|
||||
},
|
||||
}
|
||||
p.IndexParams.Init(mgr)
|
||||
assert.NotPanics(t, func() {
|
||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||
})
|
||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType)
|
||||
})
|
||||
}
|
||||
|
||||
func TestScalarAutoIndexParams_build(t *testing.T) {
|
||||
|
|
|
@ -65,7 +65,7 @@ type ComponentParam struct {
|
|||
|
||||
CommonCfg commonConfig
|
||||
QuotaConfig quotaConfig
|
||||
AutoIndexConfig autoIndexConfig
|
||||
AutoIndexConfig AutoIndexConfig
|
||||
GpuConfig gpuConfig
|
||||
TraceCfg traceConfig
|
||||
|
||||
|
|
|
@ -607,7 +607,7 @@ func TestCreateIndexJsonField(t *testing.T) {
|
|||
// create vector index on json field
|
||||
idx := index.NewSCANNIndex(entity.L2, 8, false)
|
||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultJSONFieldName, idx).WithIndexName("json_index"))
|
||||
common.CheckErr(t, err, false, "data type should be FloatVector, Float16Vector or BFloat16Vector")
|
||||
common.CheckErr(t, err, false, "index SCANN only supports vector data type")
|
||||
|
||||
// create scalar index on json field
|
||||
type scalarIndexError struct {
|
||||
|
@ -653,7 +653,7 @@ func TestCreateUnsupportedIndexArrayField(t *testing.T) {
|
|||
if field.DataType == entity.FieldTypeArray {
|
||||
// create vector index
|
||||
_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, vectorIdx).WithIndexName("vector_index"))
|
||||
common.CheckErr(t, err1, false, "data type should be FloatVector, Float16Vector or BFloat16Vector")
|
||||
common.CheckErr(t, err1, false, "index SCANN only supports vector data type")
|
||||
|
||||
// create scalar index
|
||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxErr.idx))
|
||||
|
@ -840,11 +840,11 @@ func TestCreateSparseIndexInvalidParams(t *testing.T) {
|
|||
for _, drb := range []float64{-0.3, 1.3} {
|
||||
idxInverted := index.NewSparseInvertedIndex(entity.IP, drb)
|
||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted))
|
||||
common.CheckErr(t, err, false, "must be in range [0, 1)")
|
||||
common.CheckErr(t, err, false, "Out of range in json: param 'drop_ratio_build'")
|
||||
|
||||
idxWand := index.NewSparseWANDIndex(entity.IP, drb)
|
||||
_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand))
|
||||
common.CheckErr(t, err1, false, "must be in range [0, 1)")
|
||||
common.CheckErr(t, err1, false, "Out of range in json: param 'drop_ratio_build'")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -944,20 +944,22 @@ func TestCreateVectorIndexScalarField(t *testing.T) {
|
|||
// create float vector index on scalar field
|
||||
for _, idx := range hp.GenAllFloatIndex(entity.COSINE) {
|
||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idx))
|
||||
common.CheckErr(t, err, false, "can't build hnsw in not vector type",
|
||||
"data type should be FloatVector, Float16Vector or BFloat16Vector")
|
||||
expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idx.IndexType())
|
||||
common.CheckErr(t, err, false, expErrorMsg)
|
||||
}
|
||||
|
||||
// create binary vector index on scalar field
|
||||
for _, idxBinary := range []index.Index{index.NewBinFlatIndex(entity.IP), index.NewBinIvfFlatIndex(entity.COSINE, 64)} {
|
||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxBinary))
|
||||
common.CheckErr(t, err, false, "binary vector is only supported")
|
||||
expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idxBinary.IndexType())
|
||||
common.CheckErr(t, err, false, expErrorMsg)
|
||||
}
|
||||
|
||||
// create sparse vector index on scalar field
|
||||
for _, idxSparse := range []index.Index{index.NewSparseInvertedIndex(entity.IP, 0.2), index.NewSparseWANDIndex(entity.IP, 0.3)} {
|
||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxSparse))
|
||||
common.CheckErr(t, err, false, "only sparse float vector is supported for the specified index")
|
||||
expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idxSparse.IndexType())
|
||||
common.CheckErr(t, err, false, expErrorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -972,7 +974,7 @@ func TestCreateIndexInvalidParams(t *testing.T) {
|
|||
_, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
|
||||
|
||||
// invalid IvfFlat nlist [1, 65536]
|
||||
errMsg := "nlist out of range: [1, 65536]"
|
||||
errMsg := "Out of range in json: param 'nlist'"
|
||||
for _, invalidNlist := range []int{0, -1, 65536 + 1} {
|
||||
// IvfFlat
|
||||
idxIvfFlat := index.NewIvfFlatIndex(entity.L2, invalidNlist)
|
||||
|
@ -997,7 +999,7 @@ func TestCreateIndexInvalidParams(t *testing.T) {
|
|||
// IvfFlat
|
||||
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 8, invalidNBits)
|
||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxIvfPq))
|
||||
common.CheckErr(t, err, false, "parameter `nbits` out of range, expect range [1,64]")
|
||||
common.CheckErr(t, err, false, "Out of range in json: param 'nbits'")
|
||||
}
|
||||
|
||||
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 7, 8)
|
||||
|
@ -1009,13 +1011,13 @@ func TestCreateIndexInvalidParams(t *testing.T) {
|
|||
// IvfFlat
|
||||
idxHnsw := index.NewHNSWIndex(entity.L2, invalidM, 96)
|
||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxHnsw))
|
||||
common.CheckErr(t, err, false, "M out of range: [1, 2048]")
|
||||
common.CheckErr(t, err, false, "Out of range in json: param 'M'")
|
||||
}
|
||||
for _, invalidEfConstruction := range []int{0, 2147483647 + 1} {
|
||||
// IvfFlat
|
||||
idxHnsw := index.NewHNSWIndex(entity.L2, 8, invalidEfConstruction)
|
||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxHnsw))
|
||||
common.CheckErr(t, err, false, "efConstruction out of range: [1, 2147483647]")
|
||||
common.CheckErr(t, err, false, "Out of range in json: param 'efConstruction'", "integer value out of range, key: 'efConstruction'")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -33,10 +33,10 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/util/importutilv2"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
|
@ -66,7 +66,7 @@ func (s *BulkInsertSuite) SetupTest() {
|
|||
s.autoID = false
|
||||
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.indexType = indexparamcheck.IndexHNSW
|
||||
s.indexType = "HNSW"
|
||||
s.metricType = metric.L2
|
||||
}
|
||||
|
||||
|
@ -225,29 +225,29 @@ func (s *BulkInsertSuite) TestMultiFileTypes() {
|
|||
s.fileType = fileType
|
||||
|
||||
s.vecType = schemapb.DataType_BinaryVector
|
||||
s.indexType = indexparamcheck.IndexFaissBinIvfFlat
|
||||
s.indexType = "BIN_IVF_FLAT"
|
||||
s.metricType = metric.HAMMING
|
||||
s.run()
|
||||
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.indexType = indexparamcheck.IndexHNSW
|
||||
s.indexType = "HNSW"
|
||||
s.metricType = metric.L2
|
||||
s.run()
|
||||
|
||||
s.vecType = schemapb.DataType_Float16Vector
|
||||
s.indexType = indexparamcheck.IndexHNSW
|
||||
s.indexType = "HNSW"
|
||||
s.metricType = metric.L2
|
||||
s.run()
|
||||
|
||||
s.vecType = schemapb.DataType_BFloat16Vector
|
||||
s.indexType = indexparamcheck.IndexHNSW
|
||||
s.indexType = "HNSW"
|
||||
s.metricType = metric.L2
|
||||
s.run()
|
||||
|
||||
// TODO: not support numpy for SparseFloatVector by now
|
||||
if fileType != importutilv2.Numpy {
|
||||
s.vecType = schemapb.DataType_SparseFloatVector
|
||||
s.indexType = indexparamcheck.IndexSparseWand
|
||||
s.indexType = "SPARSE_WAND"
|
||||
s.metricType = metric.IP
|
||||
s.run()
|
||||
}
|
||||
|
|
|
@ -26,23 +26,22 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
)
|
||||
|
||||
const (
|
||||
IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat
|
||||
IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ
|
||||
IndexFaissIDMap = indexparamcheck.IndexFaissIDMap
|
||||
IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat
|
||||
IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ
|
||||
IndexScaNN = indexparamcheck.IndexScaNN
|
||||
IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8
|
||||
IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap
|
||||
IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat
|
||||
IndexHNSW = indexparamcheck.IndexHNSW
|
||||
IndexDISKANN = indexparamcheck.IndexDISKANN
|
||||
IndexSparseInvertedIndex = indexparamcheck.IndexSparseInverted
|
||||
IndexSparseWand = indexparamcheck.IndexSparseWand
|
||||
IndexRaftIvfFlat = "GPU_IVF_FLAT"
|
||||
IndexRaftIvfPQ = "GPU_IVF_PQ"
|
||||
IndexFaissIDMap = "FLAT"
|
||||
IndexFaissIvfFlat = "IVF_FLAT"
|
||||
IndexFaissIvfPQ = "IVF_PQ"
|
||||
IndexScaNN = "SCANN"
|
||||
IndexFaissIvfSQ8 = "IVF_SQ8"
|
||||
IndexFaissBinIDMap = "BIN_FLAT"
|
||||
IndexFaissBinIvfFlat = "BIN_IVF_FLAT"
|
||||
IndexHNSW = "HNSW"
|
||||
IndexDISKANN = "DISKANN"
|
||||
IndexSparseInvertedIndex = "SPARSE_INVERTED_INDEX"
|
||||
IndexSparseWand = "SPARSE_WAND"
|
||||
)
|
||||
|
||||
func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
|
||||
|
|
|
@ -2103,6 +2103,8 @@ def gen_simple_index():
|
|||
continue
|
||||
elif ct.all_index_types[i] in ct.sparse_support:
|
||||
continue
|
||||
elif ct.all_index_types[i] in ct.gpu_support:
|
||||
continue
|
||||
dic = {"index_type": ct.all_index_types[i], "metric_type": "L2"}
|
||||
dic.update({"params": ct.default_all_indexes_params[i]})
|
||||
index_params.append(dic)
|
||||
|
|
|
@ -244,6 +244,7 @@ default_all_search_params_params = [{}, {"nprobe": 32}, {"nprobe": 32}, {"nprobe
|
|||
Handler_type = ["GRPC", "HTTP"]
|
||||
binary_support = ["BIN_FLAT", "BIN_IVF_FLAT"]
|
||||
sparse_support = ["SPARSE_INVERTED_INDEX", "SPARSE_WAND"]
|
||||
gpu_support = ["GPU_IVF_FLAT", "GPU_IVF_PQ"]
|
||||
default_L0_metric = "COSINE"
|
||||
float_metrics = ["L2", "IP", "COSINE"]
|
||||
binary_metrics = ["JACCARD", "HAMMING", "SUBSTRUCTURE", "SUPERSTRUCTURE"]
|
||||
|
|
|
@ -57,6 +57,8 @@ default_index_params = [
|
|||
def create_target_index(index, field_name):
|
||||
index["field_name"] = field_name
|
||||
|
||||
def gpu_support():
|
||||
return ["GPU_IVF_FLAT", "GPU_IVF_PQ"]
|
||||
|
||||
def binary_support():
|
||||
return ["BIN_FLAT", "BIN_IVF_FLAT"]
|
||||
|
@ -764,6 +766,8 @@ def gen_simple_index():
|
|||
for i in range(len(all_index_types)):
|
||||
if all_index_types[i] in binary_support():
|
||||
continue
|
||||
if all_index_types[i] in gpu_support():
|
||||
continue
|
||||
dic = {"index_type": all_index_types[i], "metric_type": "L2"}
|
||||
dic.update({"params": default_index_params[i]})
|
||||
index_params.append(dic)
|
||||
|
|
Loading…
Reference in New Issue