mirror of https://github.com/milvus-io/milvus.git
Avoid crash when index parameters are invalid
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/4973/head^2
parent
65089ea362
commit
018466a256
|
@ -19,6 +19,7 @@
|
|||
#include "utils/EasyAssert.h"
|
||||
#include "IndexWrapper.h"
|
||||
#include "indexbuilder/utils.h"
|
||||
#include "index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace indexbuilder {
|
||||
|
@ -29,14 +30,11 @@ IndexWrapper::IndexWrapper(const char* serialized_type_params, const char* seria
|
|||
|
||||
parse();
|
||||
|
||||
std::map<std::string, knowhere::IndexMode> mode_map = {{"CPU", knowhere::IndexMode::MODE_CPU},
|
||||
{"GPU", knowhere::IndexMode::MODE_GPU}};
|
||||
auto mode = get_config_by_name<std::string>("index_mode");
|
||||
auto index_mode = mode.has_value() ? mode_map[mode.value()] : knowhere::IndexMode::MODE_CPU;
|
||||
|
||||
auto index_mode = get_index_mode();
|
||||
auto index_type = get_index_type();
|
||||
auto metric_type = get_metric_type();
|
||||
AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type);
|
||||
|
||||
index_ = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(get_index_type(), index_mode);
|
||||
Assert(index_ != nullptr);
|
||||
}
|
||||
|
@ -157,6 +155,11 @@ IndexWrapper::dim() {
|
|||
void
|
||||
IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) {
|
||||
auto index_type = get_index_type();
|
||||
auto index_mode = get_index_mode();
|
||||
config_[knowhere::meta::ROWS] = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto conf_adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type);
|
||||
AssertInfo(conf_adapter->CheckTrain(config_, index_mode), "something wrong in index parameters!");
|
||||
|
||||
if (is_in_need_id_list(index_type)) {
|
||||
PanicInfo(std::string(index_type) + " doesn't support build without ids yet!");
|
||||
}
|
||||
|
@ -176,6 +179,11 @@ IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) {
|
|||
void
|
||||
IndexWrapper::BuildWithIds(const knowhere::DatasetPtr& dataset) {
|
||||
Assert(dataset->data().find(milvus::knowhere::meta::IDS) != dataset->data().end());
|
||||
auto index_type = get_index_type();
|
||||
auto index_mode = get_index_mode();
|
||||
config_[knowhere::meta::ROWS] = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto conf_adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type);
|
||||
AssertInfo(conf_adapter->CheckTrain(config_, index_mode), "something wrong in index parameters!");
|
||||
// index_->Train(dataset, config_);
|
||||
// index_->Add(dataset, config_);
|
||||
index_->BuildAll(dataset, config_);
|
||||
|
@ -281,6 +289,16 @@ IndexWrapper::get_metric_type() {
|
|||
}
|
||||
}
|
||||
|
||||
knowhere::IndexMode
|
||||
IndexWrapper::get_index_mode() {
|
||||
static std::map<std::string, knowhere::IndexMode> mode_map = {
|
||||
{"CPU", knowhere::IndexMode::MODE_CPU},
|
||||
{"GPU", knowhere::IndexMode::MODE_GPU},
|
||||
};
|
||||
auto mode = get_config_by_name<std::string>("index_mode");
|
||||
return mode.has_value() ? mode_map[mode.value()] : knowhere::IndexMode::MODE_CPU;
|
||||
}
|
||||
|
||||
std::unique_ptr<IndexWrapper::QueryResult>
|
||||
IndexWrapper::Query(const knowhere::DatasetPtr& dataset) {
|
||||
return std::move(QueryImpl(dataset, config_));
|
||||
|
|
|
@ -62,6 +62,9 @@ class IndexWrapper {
|
|||
std::string
|
||||
get_metric_type();
|
||||
|
||||
knowhere::IndexMode
|
||||
get_index_mode();
|
||||
|
||||
template <typename T>
|
||||
std::optional<T>
|
||||
get_config_by_name(std::string name);
|
||||
|
|
|
@ -35,7 +35,7 @@ CreateIndex(const char* serialized_type_params, const char* serialized_index_par
|
|||
*res_index = index.release();
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float*
|
|||
cIndex->BuildWithoutIds(ds);
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ BuildBinaryVecIndexWithoutIds(CIndex index, int64_t data_size, const uint8_t* ve
|
|||
cIndex->BuildWithoutIds(ds);
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ SerializeToSlicedBuffer(CIndex index, int32_t* buffer_size, char** res_buffer) {
|
|||
*res_buffer = binary.data;
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
@ -109,7 +109,7 @@ LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer, in
|
|||
cIndex->Load(serialized_sliced_blob_buffer, size);
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
@ -129,7 +129,7 @@ QueryOnFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors
|
|||
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
@ -153,7 +153,7 @@ QueryOnFloatVecIndexWithParam(CIndex index,
|
|||
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
@ -173,7 +173,7 @@ QueryOnBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors, C
|
|||
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
@ -197,7 +197,7 @@ QueryOnBinaryVecIndexWithParam(CIndex index,
|
|||
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
@ -213,7 +213,7 @@ CreateQueryResult(CIndexQueryResult* res) {
|
|||
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
@ -259,7 +259,7 @@ DeleteQueryResult(CIndexQueryResult res) {
|
|||
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
} catch (std::exception& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
|
|
|
@ -106,10 +106,13 @@ func (index *CIndex) BuildFloatVecIndexWithoutIds(vectors []float32) error {
|
|||
CStatus
|
||||
BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float* vectors);
|
||||
*/
|
||||
fmt.Println("before BuildFloatVecIndexWithoutIds")
|
||||
status := C.BuildFloatVecIndexWithoutIds(index.indexPtr, (C.int64_t)(len(vectors)), (*C.float)(&vectors[0]))
|
||||
errorCode := status.error_code
|
||||
fmt.Println("BuildFloatVecIndexWithoutIds error code: ", errorCode)
|
||||
if errorCode != 0 {
|
||||
errorMsg := C.GoString(status.error_msg)
|
||||
fmt.Println("BuildFloatVecIndexWithoutIds error msg: ", errorMsg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New("BuildFloatVecIndexWithoutIds failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
|
||||
}
|
||||
|
|
|
@ -440,7 +440,7 @@ func (qt *QueryTask) PostExecute() error {
|
|||
|
||||
hits := make([][]*servicepb.Hits, 0)
|
||||
for _, partialSearchResult := range filterSearchResult {
|
||||
if len(partialSearchResult.Hits) <= 0 {
|
||||
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
|
||||
filterReason += "nq is zero\n"
|
||||
continue
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue