Use early stop strategy in K-Means (#4608)

* Use early stop strategy

Signed-off-by: sahuang <xiaohaix@student.unimelb.edu.au>
pull/4617/head
Xiaohai Xu 2021-01-13 14:30:17 +08:00 committed by GitHub
parent bb79e896aa
commit 91a1494b57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 19 additions and 3 deletions

View File

@ -71,8 +71,8 @@ KnowhereResource::Initialize() {
}
// init faiss global variable
int64_t use_blas_threshold = config.engine.use_blas_threshold();
faiss::distance_compute_blas_threshold = use_blas_threshold;
faiss::distance_compute_blas_threshold = config.engine.use_blas_threshold();
faiss::early_stop_threshold = config.engine.early_stop_threshold();
int64_t clustering_type = config.engine.clustering_type();
switch (clustering_type) {

View File

@ -262,6 +262,7 @@ int split_clusters (size_t d, size_t k, size_t n,
};
ClusteringType clustering_type = ClusteringType::K_MEANS;
double early_stop_threshold = 0.0;
void Clustering::kmeans_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
size_t n_input_centroids, size_t d, size_t k,
@ -532,6 +533,7 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
// k-means iterations
float err = 0;
float prev_objective = 0;
for (int i = 0; i < niter; i++) {
double t0s = getmillisecs();
@ -600,6 +602,14 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
}
index.add (k, centroids.data());
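// Early stop: break once the relative decrease of the objective versus the previous iteration drops below early_stop_threshold percent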
float diff = (prev_objective == 0) ? std::numeric_limits<float>::max() : (prev_objective - stats.obj) / prev_objective;
prev_objective = stats.obj;
if (diff < early_stop_threshold / 100.) {
break;
}
InterruptCallback::check ();
}

View File

@ -25,9 +25,12 @@ enum ClusteringType
K_MEANS_TWO,
};
//The default algorithm use the K_MEANS
// The default clustering algorithm is K_MEANS
extern ClusteringType clustering_type;
// K-Means early-stop threshold, expressed as a percentage of relative objective improvement; defaults to 0.0
extern double early_stop_threshold;
/** Class for the clustering parameters. Can be passed to the
* constructor of the Clustering object.

View File

@ -191,6 +191,7 @@ InitConfig() {
Enum(engine.clustering_type, &ClusteringMap, ClusteringType::K_MEANS),
Enum(engine.simd_type, &SimdMap, SimdType::AUTO),
Integer(engine.statistics_level, 0, 3, 1),
Floating(engine.early_stop_threshold, 0.0, 100.0, 0.0),
Bool(system.lock.enable, true),

View File

@ -69,6 +69,7 @@ ConfigMgr::ConfigMgr() : ValueMgr(InitConfig()) {
"engine.search_combine_nq",
"engine.use_blas_threshold",
"engine.omp_thread_num",
"engine.early_stop_threshold",
};
}

View File

@ -110,6 +110,7 @@ struct ServerConfig {
Integer clustering_type;
Integer simd_type;
Integer statistics_level;
Floating early_stop_threshold;
} engine;
struct GPU {