Fix unsupported index combinations

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/4973/head^2
dragondriver 2021-01-08 14:10:10 +08:00 committed by yefu.chen
parent d17442d5be
commit 92b2df14ca
5 changed files with 50 additions and 0 deletions

View File

@ -34,6 +34,9 @@ IndexWrapper::IndexWrapper(const char* serialized_type_params, const char* seria
auto mode = get_config_by_name<std::string>("index_mode");
auto index_mode = mode.has_value() ? mode_map[mode.value()] : knowhere::IndexMode::MODE_CPU;
auto index_type = get_index_type();
auto metric_type = get_metric_type();
AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type);
index_ = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(get_index_type(), index_mode);
Assert(index_ != nullptr);
}
@ -263,6 +266,21 @@ IndexWrapper::get_index_type() {
return type.has_value() ? type.value() : knowhere::IndexEnum::INDEX_FAISS_IVFPQ;
}
std::string
IndexWrapper::get_metric_type() {
auto type = get_config_by_name<std::string>(knowhere::Metric::TYPE);
if (type.has_value()) {
return type.value();
} else {
auto index_type = get_index_type();
if (is_in_bin_list(index_type)) {
return knowhere::Metric::JACCARD;
} else {
return knowhere::Metric::L2;
}
}
}
std::unique_ptr<IndexWrapper::QueryResult>
IndexWrapper::Query(const knowhere::DatasetPtr& dataset) {
return std::move(QueryImpl(dataset, config_));

View File

@ -59,6 +59,9 @@ class IndexWrapper {
std::string
get_index_type();
std::string
get_metric_type();
template <typename T>
std::optional<T>
get_config_by_name(std::string name);

View File

@ -14,6 +14,7 @@
#include <vector>
#include <string>
#include <algorithm>
#include <tuple>
#include "index/knowhere/knowhere/index/IndexType.h"
@ -57,6 +58,14 @@ Need_BuildAll_list() {
return ret;
}
std::vector<std::tuple<std::string, std::string>>
unsupported_index_combinations() {
static std::vector<std::tuple<std::string, std::string>> ret{
std::make_tuple(std::string(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT), std::string(knowhere::Metric::L2)),
};
return ret;
}
template <typename T>
bool
is_in_list(const T& t, std::function<std::vector<T>()> list_func) {
@ -84,5 +93,11 @@ is_in_need_id_list(const milvus::knowhere::IndexType& index_type) {
return is_in_list<std::string>(index_type, Need_ID_List);
}
bool
is_unsupported(const milvus::knowhere::IndexType& index_type, const milvus::knowhere::MetricType& metric_type) {
return is_in_list<std::tuple<std::string, std::string>>(std::make_tuple(index_type, metric_type),
unsupported_index_combinations);
}
} // namespace indexbuilder
} // namespace milvus

View File

@ -14,6 +14,7 @@ package indexbuilder
import "C"
import (
"errors"
"fmt"
"strconv"
"unsafe"
@ -142,6 +143,8 @@ func (index *CIndex) Delete() error {
}
func NewCIndex(typeParams, indexParams map[string]string) (Index, error) {
fmt.Println("NNNNNNNNNNNNNNNNNNNNNNNNNNN typeParams: ", typeParams)
fmt.Println("NNNNNNNNNNNNNNNNNNNNNNNNNNN indexParams: ", indexParams)
protoTypeParams := &indexcgopb.TypeParams{
Params: make([]*commonpb.KeyValuePair, 0),
}
@ -168,10 +171,14 @@ func NewCIndex(typeParams, indexParams map[string]string) (Index, error) {
CIndex* res_index);
*/
var indexPtr C.CIndex
fmt.Println("before create index ........................................")
status := C.CreateIndex(typeParamsPointer, indexParamsPointer, &indexPtr)
fmt.Println("after create index ........................................")
errorCode := status.error_code
fmt.Println("EEEEEEEEEEEEEEEEEEEEEEEEEE error code: ", errorCode)
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
fmt.Println("EEEEEEEEEEEEEEEEEEEEEEEEEE error msg: ", errorMsg)
defer C.free(unsafe.Pointer(status.error_msg))
return nil, errors.New(" failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}

View File

@ -2,6 +2,7 @@ package indexbuilder
import (
"context"
"fmt"
"log"
"strconv"
"time"
@ -171,10 +172,12 @@ func (it *IndexBuildTask) Execute() error {
indexParams[key] = value
}
fmt.Println("before NewCIndex ..........................")
it.index, err = NewCIndex(typeParams, indexParams)
if err != nil {
return err
}
fmt.Println("after NewCIndex ..........................")
getKeyByPathNaive := func(path string) string {
// splitElements := strings.Split(path, "/")
@ -223,6 +226,7 @@ func (it *IndexBuildTask) Execute() error {
for _, value := range insertData.Data {
// TODO: BinaryVectorFieldData
fmt.Println("before build index ..................................")
floatVectorFieldData, fOk := value.(*storage.FloatVectorFieldData)
if fOk {
err = it.index.BuildFloatVecIndexWithoutIds(floatVectorFieldData.Data)
@ -238,12 +242,15 @@ func (it *IndexBuildTask) Execute() error {
return err
}
}
fmt.Println("after build index ..................................")
if !fOk && !bOk {
return errors.New("we expect FloatVectorFieldData or BinaryVectorFieldData")
}
fmt.Println("before serialize .............................................")
indexBlobs, err := it.index.Serialize()
fmt.Println("after serialize .............................................")
if err != nil {
return err
}