Avoid crash when index parameters are invalid

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/4973/head^2
dragondriver 2021-01-08 17:45:57 +08:00 committed by yefu.chen
parent 65089ea362
commit 018466a256
5 changed files with 41 additions and 17 deletions

View File

@ -19,6 +19,7 @@
#include "utils/EasyAssert.h"
#include "IndexWrapper.h"
#include "indexbuilder/utils.h"
#include "index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h"
namespace milvus {
namespace indexbuilder {
@ -29,14 +30,11 @@ IndexWrapper::IndexWrapper(const char* serialized_type_params, const char* seria
parse();
std::map<std::string, knowhere::IndexMode> mode_map = {{"CPU", knowhere::IndexMode::MODE_CPU},
{"GPU", knowhere::IndexMode::MODE_GPU}};
auto mode = get_config_by_name<std::string>("index_mode");
auto index_mode = mode.has_value() ? mode_map[mode.value()] : knowhere::IndexMode::MODE_CPU;
auto index_mode = get_index_mode();
auto index_type = get_index_type();
auto metric_type = get_metric_type();
AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type);
index_ = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(get_index_type(), index_mode);
Assert(index_ != nullptr);
}
@ -157,6 +155,11 @@ IndexWrapper::dim() {
void
IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) {
auto index_type = get_index_type();
auto index_mode = get_index_mode();
config_[knowhere::meta::ROWS] = dataset->Get<int64_t>(knowhere::meta::ROWS);
auto conf_adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type);
AssertInfo(conf_adapter->CheckTrain(config_, index_mode), "something wrong in index parameters!");
if (is_in_need_id_list(index_type)) {
PanicInfo(std::string(index_type) + " doesn't support build without ids yet!");
}
@ -176,6 +179,11 @@ IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) {
void
IndexWrapper::BuildWithIds(const knowhere::DatasetPtr& dataset) {
Assert(dataset->data().find(milvus::knowhere::meta::IDS) != dataset->data().end());
auto index_type = get_index_type();
auto index_mode = get_index_mode();
config_[knowhere::meta::ROWS] = dataset->Get<int64_t>(knowhere::meta::ROWS);
auto conf_adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type);
AssertInfo(conf_adapter->CheckTrain(config_, index_mode), "something wrong in index parameters!");
// index_->Train(dataset, config_);
// index_->Add(dataset, config_);
index_->BuildAll(dataset, config_);
@ -281,6 +289,16 @@ IndexWrapper::get_metric_type() {
}
}
knowhere::IndexMode
IndexWrapper::get_index_mode() {
static std::map<std::string, knowhere::IndexMode> mode_map = {
{"CPU", knowhere::IndexMode::MODE_CPU},
{"GPU", knowhere::IndexMode::MODE_GPU},
};
auto mode = get_config_by_name<std::string>("index_mode");
return mode.has_value() ? mode_map[mode.value()] : knowhere::IndexMode::MODE_CPU;
}
std::unique_ptr<IndexWrapper::QueryResult>
IndexWrapper::Query(const knowhere::DatasetPtr& dataset) {
return std::move(QueryImpl(dataset, config_));

View File

@ -62,6 +62,9 @@ class IndexWrapper {
std::string
get_metric_type();
knowhere::IndexMode
get_index_mode();
template <typename T>
std::optional<T>
get_config_by_name(std::string name);

View File

@ -35,7 +35,7 @@ CreateIndex(const char* serialized_type_params, const char* serialized_index_par
*res_index = index.release();
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
@ -59,7 +59,7 @@ BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float*
cIndex->BuildWithoutIds(ds);
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
@ -77,7 +77,7 @@ BuildBinaryVecIndexWithoutIds(CIndex index, int64_t data_size, const uint8_t* ve
cIndex->BuildWithoutIds(ds);
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
@ -94,7 +94,7 @@ SerializeToSlicedBuffer(CIndex index, int32_t* buffer_size, char** res_buffer) {
*res_buffer = binary.data;
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
@ -109,7 +109,7 @@ LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer, in
cIndex->Load(serialized_sliced_blob_buffer, size);
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
@ -129,7 +129,7 @@ QueryOnFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
@ -153,7 +153,7 @@ QueryOnFloatVecIndexWithParam(CIndex index,
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
@ -173,7 +173,7 @@ QueryOnBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors, C
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
@ -197,7 +197,7 @@ QueryOnBinaryVecIndexWithParam(CIndex index,
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
@ -213,7 +213,7 @@ CreateQueryResult(CIndexQueryResult* res) {
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
@ -259,7 +259,7 @@ DeleteQueryResult(CIndexQueryResult res) {
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
} catch (std::exception& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}

View File

@ -106,10 +106,13 @@ func (index *CIndex) BuildFloatVecIndexWithoutIds(vectors []float32) error {
CStatus
BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float* vectors);
*/
fmt.Println("before BuildFloatVecIndexWithoutIds")
status := C.BuildFloatVecIndexWithoutIds(index.indexPtr, (C.int64_t)(len(vectors)), (*C.float)(&vectors[0]))
errorCode := status.error_code
fmt.Println("BuildFloatVecIndexWithoutIds error code: ", errorCode)
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
fmt.Println("BuildFloatVecIndexWithoutIds error msg: ", errorMsg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New("BuildFloatVecIndexWithoutIds failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}

View File

@ -440,7 +440,7 @@ func (qt *QueryTask) PostExecute() error {
hits := make([][]*servicepb.Hits, 0)
for _, partialSearchResult := range filterSearchResult {
if len(partialSearchResult.Hits) <= 0 {
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
filterReason += "nq is zero\n"
continue
}