fix unitest crash

Former-commit-id: 977fe7218389c0cc892f70b6499ffca834e54283
pull/191/head
starlord 2019-08-31 18:07:21 +08:00
parent 03bcdfe4f4
commit 8c63489df3
1 changed files with 10 additions and 8 deletions

View File

@ -21,11 +21,16 @@ using ::testing::TestWithParam;
using ::testing::Values; using ::testing::Values;
using ::testing::Combine; using ::testing::Combine;
constexpr int64_t DIM = 512;
constexpr int64_t NB = 1000000;
class KnowhereWrapperTest class KnowhereWrapperTest
: public TestWithParam<::std::tuple<IndexType, std::string, int, int, int, int, Config, Config>> { : public TestWithParam<::std::tuple<IndexType, std::string, int, int, int, int, Config, Config>> {
protected: protected:
void SetUp() override { void SetUp() override {
zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(0);
zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(1);
std::string generator_type; std::string generator_type;
std::tie(index_type, generator_type, dim, nb, nq, k, train_cfg, search_cfg) = GetParam(); std::tie(index_type, generator_type, dim, nb, nq, k, train_cfg, search_cfg) = GetParam();
@ -67,8 +72,8 @@ class KnowhereWrapperTest
Config train_cfg; Config train_cfg;
Config search_cfg; Config search_cfg;
int dim = 512; int dim = DIM;
int nb = 1000000; int nb = NB;
int nq = 10; int nq = 10;
int k = 10; int k = 10;
std::vector<float> xb; std::vector<float> xb;
@ -106,9 +111,9 @@ INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest,
// Config::object{{"dim", 64}, {"k", 10}} // Config::object{{"dim", 64}, {"k", 10}}
// ), // ),
std::make_tuple(IndexType::FAISS_IVFSQ8_MIX, "Default", std::make_tuple(IndexType::FAISS_IVFSQ8_MIX, "Default",
512, 1000000, 10, 10, DIM, NB, 10, 10,
Config::object{{"dim", 512}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}}, Config::object{{"dim", DIM}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}},
Config::object{{"dim", 512}, {"k", 10}, {"nprobe", 5}} Config::object{{"dim", DIM}, {"k", 10}, {"nprobe", 5}}
) )
// std::make_tuple(IndexType::NSG_MIX, "Default", // std::make_tuple(IndexType::NSG_MIX, "Default",
// 128, 250000, 10, 10, // 128, 250000, 10, 10,
@ -139,9 +144,6 @@ TEST_P(KnowhereWrapperTest, base_test) {
TEST_P(KnowhereWrapperTest, to_gpu_test) { TEST_P(KnowhereWrapperTest, to_gpu_test) {
EXPECT_EQ(index_->GetType(), index_type); EXPECT_EQ(index_->GetType(), index_type);
zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(0);
zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(1);
auto elems = nq * k; auto elems = nq * k;
std::vector<int64_t> res_ids(elems); std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems); std::vector<float> res_dis(elems);