mirror of https://github.com/milvus-io/milvus.git
Get SIMD type used in faiss (#8849)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/8897/head
parent
546499ff63
commit
a10f421c14
|
@ -14,6 +14,7 @@
|
|||
#include "knowhere/archive/KnowhereConfig.h"
|
||||
#include "easyloggingpp/easylogging++.h"
|
||||
#include "ConfigKnowhere.h"
|
||||
#include "faiss/FaissHook.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace config {
|
||||
|
@ -36,7 +37,7 @@ KnowhereInitImpl() {
|
|||
std::call_once(init_knowhere_once_, init);
|
||||
}
|
||||
|
||||
void
|
||||
std::string
|
||||
KnowhereSetSimdType(const char* value) {
|
||||
milvus::engine::KnowhereConfig::SimdType simd_type;
|
||||
if (strcmp(value, "auto") == 0) {
|
||||
|
@ -50,7 +51,7 @@ KnowhereSetSimdType(const char* value) {
|
|||
} else {
|
||||
PanicInfo("invalid SIMD type: " + std::string(value));
|
||||
}
|
||||
milvus::engine::KnowhereConfig::SetSimdType(simd_type);
|
||||
return milvus::engine::KnowhereConfig::SetSimdType(simd_type);
|
||||
}
|
||||
|
||||
} // namespace config
|
||||
|
|
|
@ -17,7 +17,7 @@ namespace milvus::config {
|
|||
void
|
||||
KnowhereInitImpl();
|
||||
|
||||
void
|
||||
std::string
|
||||
KnowhereSetSimdType(const char*);
|
||||
|
||||
} // namespace milvus::config
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "utils/ConfigUtils.h"
|
||||
#include "utils/Error.h"
|
||||
#include "utils/Log.h"
|
||||
#include "index/knowhere/knowhere/common/Exception.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -36,7 +37,7 @@ namespace engine {
|
|||
|
||||
constexpr int64_t M_BYTE = 1024 * 1024;
|
||||
|
||||
Status
|
||||
std::string
|
||||
KnowhereConfig::SetSimdType(const SimdType simd_type) {
|
||||
if (simd_type == SimdType::AVX512) {
|
||||
faiss::faiss_use_avx512 = true;
|
||||
|
@ -58,12 +59,11 @@ KnowhereConfig::SetSimdType(const SimdType simd_type) {
|
|||
|
||||
std::string cpu_flag;
|
||||
if (faiss::hook_init(cpu_flag)) {
|
||||
std::cout << "FAISS hook " << cpu_flag << std::endl;
|
||||
LOG_KNOWHERE_DEBUG_ << "FAISS hook " << cpu_flag;
|
||||
return Status::OK();
|
||||
return cpu_flag;
|
||||
}
|
||||
|
||||
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
|
||||
KNOWHERE_THROW_MSG("FAISS hook fail, CPU not supported!");
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "utils/Status.h"
|
||||
|
||||
|
@ -30,7 +31,7 @@ class KnowhereConfig {
|
|||
AVX512, // only enable AVX512
|
||||
};
|
||||
|
||||
static Status
|
||||
static std::string
|
||||
SetSimdType(const SimdType simd_type);
|
||||
|
||||
/**
|
||||
|
|
|
@ -31,9 +31,9 @@ class KnowhereException : public std::exception {
|
|||
|
||||
#define KNOHWERE_ERROR_MSG(MSG) printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what())
|
||||
|
||||
#define KNOWHERE_THROW_MSG(MSG) \
|
||||
do { \
|
||||
throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
|
||||
#define KNOWHERE_THROW_MSG(MSG) \
|
||||
do { \
|
||||
throw milvus::knowhere::KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
|
||||
} while (false)
|
||||
|
||||
#define KNOHERE_THROW_FORMAT(FMT, ...) \
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <string.h>
|
||||
#include "config/ConfigKnowhere.h"
|
||||
#include "indexbuilder/init_c.h"
|
||||
|
||||
|
@ -17,7 +18,12 @@ IndexBuilderInit() {
|
|||
milvus::config::KnowhereInitImpl();
|
||||
}
|
||||
|
||||
void
|
||||
// return value must be freed by the caller
|
||||
char*
|
||||
IndexBuilderSetSimdType(const char* value) {
|
||||
milvus::config::KnowhereSetSimdType(value);
|
||||
auto real_type = milvus::config::KnowhereSetSimdType(value);
|
||||
char* ret = reinterpret_cast<char*>(malloc(real_type.length() + 1));
|
||||
memcpy(ret, real_type.c_str(), real_type.length());
|
||||
ret[real_type.length()] = 0;
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -18,7 +18,8 @@ extern "C" {
|
|||
void
|
||||
IndexBuilderInit();
|
||||
|
||||
void
|
||||
// return value must be freed by the caller
|
||||
char*
|
||||
IndexBuilderSetSimdType(const char*);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -29,10 +29,15 @@ SegcoreSetChunkRows(const int64_t value) {
|
|||
LOG_SEGCORE_DEBUG_ << "set config chunk_size: " << config.get_chunk_rows();
|
||||
}
|
||||
|
||||
extern "C" void
|
||||
// return value must be freed by the caller
|
||||
extern "C" char*
|
||||
SegcoreSetSimdType(const char* value) {
|
||||
milvus::config::KnowhereSetSimdType(value);
|
||||
LOG_SEGCORE_DEBUG_ << "set config simd_type: " << value;
|
||||
auto real_type = milvus::config::KnowhereSetSimdType(value);
|
||||
char* ret = reinterpret_cast<char*>(malloc(real_type.length() + 1));
|
||||
memcpy(ret, real_type.c_str(), real_type.length());
|
||||
ret[real_type.length()] = 0;
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace milvus::segcore
|
||||
|
|
|
@ -21,7 +21,8 @@ SegcoreInit();
|
|||
void
|
||||
SegcoreSetChunkRows(const int64_t);
|
||||
|
||||
void
|
||||
// return value must be freed by the caller
|
||||
char*
|
||||
SegcoreSetSimdType(const char*);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -114,9 +114,11 @@ func (i *IndexNode) Register() error {
|
|||
func (i *IndexNode) initKnowhere() {
|
||||
C.IndexBuilderInit()
|
||||
|
||||
// override segcore SIMD type
|
||||
// override index builder SIMD type
|
||||
cSimdType := C.CString(Params.SimdType)
|
||||
C.IndexBuilderSetSimdType(cSimdType)
|
||||
cRealSimdType := C.IndexBuilderSetSimdType(cSimdType)
|
||||
Params.SimdType = C.GoString(cRealSimdType)
|
||||
C.free(unsafe.Pointer(cRealSimdType))
|
||||
C.free(unsafe.Pointer(cSimdType))
|
||||
}
|
||||
|
||||
|
|
|
@ -124,7 +124,9 @@ func (node *QueryNode) InitSegcore() {
|
|||
|
||||
// override segcore SIMD type
|
||||
cSimdType := C.CString(Params.SimdType)
|
||||
C.SegcoreSetSimdType(cSimdType)
|
||||
cRealSimdType := C.SegcoreSetSimdType(cSimdType)
|
||||
Params.SimdType = C.GoString(cRealSimdType)
|
||||
C.free(unsafe.Pointer(cRealSimdType))
|
||||
C.free(unsafe.Pointer(cSimdType))
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue