mirror of https://github.com/milvus-io/milvus.git
Use early stop strategy in K-Means (#4608)
* Use early stop strategy
Signed-off-by: sahuang <xiaohaix@student.unimelb.edu.au>
pull/4617/head
parent
bb79e896aa
commit
91a1494b57
|
@ -71,8 +71,8 @@ KnowhereResource::Initialize() {
|
|||
}
|
||||
|
||||
// init faiss global variable
|
||||
int64_t use_blas_threshold = config.engine.use_blas_threshold();
|
||||
faiss::distance_compute_blas_threshold = use_blas_threshold;
|
||||
faiss::distance_compute_blas_threshold = config.engine.use_blas_threshold();
|
||||
faiss::early_stop_threshold = config.engine.early_stop_threshold();
|
||||
|
||||
int64_t clustering_type = config.engine.clustering_type();
|
||||
switch (clustering_type) {
|
||||
|
|
|
@ -262,6 +262,7 @@ int split_clusters (size_t d, size_t k, size_t n,
|
|||
};
|
||||
|
||||
ClusteringType clustering_type = ClusteringType::K_MEANS;
|
||||
double early_stop_threshold = 0.0;
|
||||
|
||||
void Clustering::kmeans_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
|
||||
size_t n_input_centroids, size_t d, size_t k,
|
||||
|
@ -532,6 +533,7 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|||
// k-means iterations
|
||||
|
||||
float err = 0;
|
||||
float prev_objective = 0;
|
||||
for (int i = 0; i < niter; i++) {
|
||||
double t0s = getmillisecs();
|
||||
|
||||
|
@ -600,6 +602,14 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|||
}
|
||||
|
||||
index.add (k, centroids.data());
|
||||
|
||||
// Early stop strategy
|
||||
float diff = (prev_objective == 0) ? std::numeric_limits<float>::max() : (prev_objective - stats.obj) / prev_objective;
|
||||
prev_objective = stats.obj;
|
||||
if (diff < early_stop_threshold / 100.) {
|
||||
break;
|
||||
}
|
||||
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
|
||||
|
|
|
@ -25,9 +25,12 @@ enum ClusteringType
|
|||
K_MEANS_TWO,
|
||||
};
|
||||
|
||||
//The default algorithm use the K_MEANS
|
||||
// The default clustering algorithm is K_MEANS
|
||||
extern ClusteringType clustering_type;
|
||||
|
||||
// K-Means early-stop threshold, in percent (0-100): iterations stop once the
// relative decrease of the objective falls below it; defaults to 0.0
|
||||
extern double early_stop_threshold;
|
||||
|
||||
|
||||
/** Class for the clustering parameters. Can be passed to the
|
||||
* constructor of the Clustering object.
|
||||
|
|
|
@ -191,6 +191,7 @@ InitConfig() {
|
|||
Enum(engine.clustering_type, &ClusteringMap, ClusteringType::K_MEANS),
|
||||
Enum(engine.simd_type, &SimdMap, SimdType::AUTO),
|
||||
Integer(engine.statistics_level, 0, 3, 1),
|
||||
Floating(engine.early_stop_threshold, 0.0, 100.0, 0.0),
|
||||
|
||||
Bool(system.lock.enable, true),
|
||||
|
||||
|
|
|
@ -69,6 +69,7 @@ ConfigMgr::ConfigMgr() : ValueMgr(InitConfig()) {
|
|||
"engine.search_combine_nq",
|
||||
"engine.use_blas_threshold",
|
||||
"engine.omp_thread_num",
|
||||
"engine.early_stop_threshold",
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -110,6 +110,7 @@ struct ServerConfig {
|
|||
Integer clustering_type;
|
||||
Integer simd_type;
|
||||
Integer statistics_level;
|
||||
Floating early_stop_threshold;
|
||||
} engine;
|
||||
|
||||
struct GPU {
|
||||
|
|
Loading…
Reference in New Issue