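Fix handling of overly large parameters: cap nlist/nprobe at 65536, clamp nprobe to nlist, guard against zero block size (#3684)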

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/3710/head
shengjun.li 2020-09-10 19:32:51 +08:00 committed by GitHub
parent dad5806b61
commit 0af3804974
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 16 additions and 8 deletions

View File

@@ -84,7 +84,7 @@ MatchNlist(int64_t size, int64_t nlist) {
bool
IVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
-static int64_t MAX_NLIST = 999999;
+static int64_t MAX_NLIST = 65536;
static int64_t MIN_NLIST = 1;
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
@@ -109,7 +109,7 @@ IVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
bool
IVFConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) {
static int64_t MIN_NPROBE = 1;
-static int64_t MAX_NPROBE = 999999; // todo(linxj): [1, nlist]
+static int64_t MAX_NPROBE = 65536; // todo(linxj): [1, nlist]
if (mode == IndexMode::MODE_GPU) {
#ifdef MILVUS_GPU_VERSION
@@ -133,7 +133,7 @@ IVFSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
bool
IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t DEFAULT_NBITS = 8;
-static int64_t MAX_NLIST = 999999;
+static int64_t MAX_NLIST = 65536;
static int64_t MIN_NLIST = 1;
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::IP};

View File

@@ -24,6 +24,7 @@
#endif
#include <fiu-local.h>
+#include <algorithm>
#include <chrono>
#include <memory>
#include <string>
@@ -321,7 +322,7 @@ void
IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
-ivf_index->nprobe = params->nprobe;
+ivf_index->nprobe = std::min(params->nprobe, ivf_index->invlists->nlist);
stdclock::time_point before = stdclock::now();
if (params->nprobe > 1 && n <= 4) {
ivf_index->parallel_mode = 1;

View File

@@ -9,6 +9,7 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
+#include <algorithm>
#include <memory>
#include <faiss/gpu/GpuCloner.h>
@@ -140,7 +141,7 @@ GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
fiu_do_on("GPUIVF.search_impl.invald_index", device_index = nullptr);
if (device_index) {
-device_index->nprobe = config[IndexParams::nprobe];
+device_index->nprobe = std::min(static_cast<int>(config[IndexParams::nprobe]), device_index->nlist);
ResScope rs(res_, gpu_id_);
// if query size > 2048 we search by blocks to avoid malloc issue

View File

@@ -152,6 +152,9 @@ static void knn_inner_product_sse (const float * x,
size_t block_x = std::min(
get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))),
nx);
+if (block_x == 0) {
+block_x = 1;
+}
size_t all_heap_size = block_x * k * thread_max_num;
float *value = new float[all_heap_size];
@@ -261,6 +264,9 @@ static void knn_L2sqr_sse (
size_t block_x = std::min(
get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))),
nx);
+if (block_x == 0) {
+block_x = 1;
+}
size_t all_heap_size = block_x * k * thread_max_num;
float *value = new float[all_heap_size];

View File

@@ -197,14 +197,14 @@ ValidationUtil::ValidateIndexParams(const milvus::json& index_params,
case (int32_t)engine::EngineType::FAISS_IVFSQ8:
case (int32_t)engine::EngineType::FAISS_IVFSQ8H:
case (int32_t)engine::EngineType::FAISS_BIN_IVFFLAT: {
-auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 999999);
+auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 65536);
if (!status.ok()) {
return status;
}
break;
}
case (int32_t)engine::EngineType::FAISS_PQ: {
-auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 999999);
+auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 65536);
if (!status.ok()) {
return status;
}
@@ -292,7 +292,7 @@ ValidationUtil::ValidateSearchParams(const milvus::json& search_params,
case (int32_t)engine::EngineType::FAISS_IVFSQ8H:
case (int32_t)engine::EngineType::FAISS_BIN_IVFFLAT:
case (int32_t)engine::EngineType::FAISS_PQ: {
-auto status = CheckParameterRange(search_params, knowhere::IndexParams::nprobe, 1, 999999);
+auto status = CheckParameterRange(search_params, knowhere::IndexParams::nprobe, 1, 65536);
if (!status.ok()) {
return status;
}