Get SIMD type used in faiss (#8849)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/8897/head
dragondriver 2021-09-29 20:50:19 +08:00 committed by GitHub
parent 546499ff63
commit a10f421c14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 39 additions and 20 deletions

View File

@ -14,6 +14,7 @@
#include "knowhere/archive/KnowhereConfig.h"
#include "easyloggingpp/easylogging++.h"
#include "ConfigKnowhere.h"
#include "faiss/FaissHook.h"
namespace milvus {
namespace config {
@ -36,7 +37,7 @@ KnowhereInitImpl() {
std::call_once(init_knowhere_once_, init);
}
void
std::string
KnowhereSetSimdType(const char* value) {
milvus::engine::KnowhereConfig::SimdType simd_type;
if (strcmp(value, "auto") == 0) {
@ -50,7 +51,7 @@ KnowhereSetSimdType(const char* value) {
} else {
PanicInfo("invalid SIMD type: " + std::string(value));
}
milvus::engine::KnowhereConfig::SetSimdType(simd_type);
return milvus::engine::KnowhereConfig::SetSimdType(simd_type);
}
} // namespace config

View File

@ -17,7 +17,7 @@ namespace milvus::config {
void
KnowhereInitImpl();
void
std::string
KnowhereSetSimdType(const char*);
} // namespace milvus::config

View File

@ -27,6 +27,7 @@
#include "utils/ConfigUtils.h"
#include "utils/Error.h"
#include "utils/Log.h"
#include "index/knowhere/knowhere/common/Exception.h"
#include <string>
#include <vector>
@ -36,7 +37,7 @@ namespace engine {
constexpr int64_t M_BYTE = 1024 * 1024;
Status
std::string
KnowhereConfig::SetSimdType(const SimdType simd_type) {
if (simd_type == SimdType::AVX512) {
faiss::faiss_use_avx512 = true;
@ -58,12 +59,11 @@ KnowhereConfig::SetSimdType(const SimdType simd_type) {
std::string cpu_flag;
if (faiss::hook_init(cpu_flag)) {
std::cout << "FAISS hook " << cpu_flag << std::endl;
LOG_KNOWHERE_DEBUG_ << "FAISS hook " << cpu_flag;
return Status::OK();
return cpu_flag;
}
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
KNOWHERE_THROW_MSG("FAISS hook fail, CPU not supported!");
}
void

View File

@ -12,6 +12,7 @@
#pragma once
#include <vector>
#include <string>
#include "utils/Status.h"
@ -30,7 +31,7 @@ class KnowhereConfig {
AVX512, // only enable AVX512
};
static Status
static std::string
SetSimdType(const SimdType simd_type);
/**

View File

@ -31,9 +31,9 @@ class KnowhereException : public std::exception {
#define KNOHWERE_ERROR_MSG(MSG) printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what())
#define KNOWHERE_THROW_MSG(MSG) \
do { \
throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
#define KNOWHERE_THROW_MSG(MSG) \
do { \
throw milvus::knowhere::KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
} while (false)
#define KNOHERE_THROW_FORMAT(FMT, ...) \

View File

@ -9,6 +9,7 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <string.h>
#include "config/ConfigKnowhere.h"
#include "indexbuilder/init_c.h"
@ -17,7 +18,12 @@ IndexBuilderInit() {
milvus::config::KnowhereInitImpl();
}
void
// return value must be freed by the caller
char*
IndexBuilderSetSimdType(const char* value) {
milvus::config::KnowhereSetSimdType(value);
auto real_type = milvus::config::KnowhereSetSimdType(value);
char* ret = reinterpret_cast<char*>(malloc(real_type.length() + 1));
memcpy(ret, real_type.c_str(), real_type.length());
ret[real_type.length()] = 0;
return ret;
}

View File

@ -18,7 +18,8 @@ extern "C" {
void
IndexBuilderInit();
void
// return value must be freed by the caller
char*
IndexBuilderSetSimdType(const char*);
#ifdef __cplusplus

View File

@ -29,10 +29,15 @@ SegcoreSetChunkRows(const int64_t value) {
LOG_SEGCORE_DEBUG_ << "set config chunk_size: " << config.get_chunk_rows();
}
extern "C" void
// return value must be freed by the caller
extern "C" char*
SegcoreSetSimdType(const char* value) {
milvus::config::KnowhereSetSimdType(value);
LOG_SEGCORE_DEBUG_ << "set config simd_type: " << value;
auto real_type = milvus::config::KnowhereSetSimdType(value);
char* ret = reinterpret_cast<char*>(malloc(real_type.length() + 1));
memcpy(ret, real_type.c_str(), real_type.length());
ret[real_type.length()] = 0;
return ret;
}
} // namespace milvus::segcore

View File

@ -21,7 +21,8 @@ SegcoreInit();
void
SegcoreSetChunkRows(const int64_t);
void
// return value must be freed by the caller
char*
SegcoreSetSimdType(const char*);
#ifdef __cplusplus

View File

@ -114,9 +114,11 @@ func (i *IndexNode) Register() error {
func (i *IndexNode) initKnowhere() {
C.IndexBuilderInit()
// override segcore SIMD type
// override index builder SIMD type
cSimdType := C.CString(Params.SimdType)
C.IndexBuilderSetSimdType(cSimdType)
cRealSimdType := C.IndexBuilderSetSimdType(cSimdType)
Params.SimdType = C.GoString(cRealSimdType)
C.free(unsafe.Pointer(cRealSimdType))
C.free(unsafe.Pointer(cSimdType))
}

View File

@ -124,7 +124,9 @@ func (node *QueryNode) InitSegcore() {
// override segcore SIMD type
cSimdType := C.CString(Params.SimdType)
C.SegcoreSetSimdType(cSimdType)
cRealSimdType := C.SegcoreSetSimdType(cSimdType)
Params.SimdType = C.GoString(cRealSimdType)
C.free(unsafe.Pointer(cRealSimdType))
C.free(unsafe.Pointer(cSimdType))
}