mirror of https://github.com/milvus-io/milvus.git
parent
1962b6e78b
commit
83cbe0f490
|
@ -53,64 +53,13 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
// QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, Config());
|
||||
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
auto pf_dist = (float*)malloc(p_dist_size);
|
||||
int32_t* pi_dist = (int32_t*)p_dist;
|
||||
for (int i = 0; i < elems; i++) {
|
||||
*(pf_dist + i) = (float)(*(pi_dist + i));
|
||||
}
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, pf_dist);
|
||||
free(p_dist);
|
||||
} else {
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
}
|
||||
return ret_ds;
|
||||
}
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
BinaryIDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
auto dim = dataset_ptr->Get<int64_t>(meta::DIM);
|
||||
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
int64_t k = config[meta::TOPK].get<int64_t>();
|
||||
auto elems = rows * k;
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
auto* pdistances = (int32_t*)p_dist;
|
||||
index_->search_by_id(rows, p_data, k, pdistances, p_id, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
auto pf_dist = (float*)malloc(p_dist_size);
|
||||
int32_t* pi_dist = (int32_t*)p_dist;
|
||||
for (int i = 0; i < elems; i++) {
|
||||
*(pf_dist + i) = (float)(*(pi_dist + i));
|
||||
}
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, pf_dist);
|
||||
free(p_dist);
|
||||
} else {
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
}
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
|
||||
return ret_ds;
|
||||
}
|
||||
#endif
|
||||
|
||||
int64_t
|
||||
BinaryIDMAP::Count() {
|
||||
|
@ -187,39 +136,25 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config)
|
|||
index_->add_with_ids(rows, (uint8_t*)p_data, new_ids.data());
|
||||
}
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
// GETBINARYTENSOR(dataset_ptr)
|
||||
// auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
auto elems = dataset_ptr->Get<int64_t>(meta::DIM);
|
||||
|
||||
size_t p_x_size = sizeof(uint8_t) * elems;
|
||||
auto p_x = (uint8_t*)malloc(p_x_size);
|
||||
|
||||
index_->get_vector_by_id(1, p_data, p_x, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::TENSOR, p_x);
|
||||
return ret_ds;
|
||||
}
|
||||
#endif
|
||||
|
||||
void
|
||||
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
int32_t* pdistances = (int32_t*)distances;
|
||||
|
||||
auto flat_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get())->index;
|
||||
auto default_type = flat_index->metric_type;
|
||||
if (config.contains(Metric::TYPE))
|
||||
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
|
||||
|
||||
int32_t* i_distances = reinterpret_cast<int32_t*>(distances);
|
||||
index_->search(n, (uint8_t*)data, k, i_distances, labels, bitset_);
|
||||
|
||||
// if hamming, it need transform int32 to float
|
||||
if (flat_index->metric_type == faiss::METRIC_Hamming) {
|
||||
int64_t num = n * k;
|
||||
for (int64_t i = 0; i < num; i++) {
|
||||
distances[i] = static_cast<float>(i_distances[i]);
|
||||
}
|
||||
}
|
||||
|
||||
flat_index->metric_type = default_type;
|
||||
}
|
||||
|
||||
|
|
|
@ -50,11 +50,6 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
|||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
QueryById(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
#endif
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
|
@ -66,11 +61,6 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
|||
return Count() * Dim() / 8;
|
||||
}
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
#endif
|
||||
|
||||
virtual const uint8_t*
|
||||
GetRawVectors();
|
||||
|
||||
|
|
|
@ -62,19 +62,9 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
|||
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
auto pf_dist = (float*)malloc(p_dist_size);
|
||||
int32_t* pi_dist = (int32_t*)p_dist;
|
||||
for (int i = 0; i < elems; i++) {
|
||||
*(pf_dist + i) = (float)(*(pi_dist + i));
|
||||
}
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, pf_dist);
|
||||
free(p_dist);
|
||||
} else {
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
}
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
|
||||
return ret_ds;
|
||||
} catch (faiss::FaissException& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
|
@ -215,11 +205,10 @@ BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances
|
|||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
int32_t* pdistances = (int32_t*)distances;
|
||||
stdclock::time_point before = stdclock::now();
|
||||
|
||||
// todo: remove static cast (zhiru)
|
||||
static_cast<faiss::IndexBinary*>(index_.get())->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
|
||||
stdclock::time_point before = stdclock::now();
|
||||
int32_t* i_distances = reinterpret_cast<int32_t*>(distances);
|
||||
index_->search(n, (uint8_t*)data, k, i_distances, labels, bitset_);
|
||||
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
|
@ -228,6 +217,14 @@ BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances
|
|||
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
|
||||
faiss::indexIVF_stats.quantization_time = 0;
|
||||
faiss::indexIVF_stats.search_time = 0;
|
||||
|
||||
// if hamming, it need transform int32 to float
|
||||
if (ivf_index->metric_type == faiss::METRIC_Hamming) {
|
||||
int64_t num = n * k;
|
||||
for (int64_t i = 0; i < num; i++) {
|
||||
distances[i] = static_cast<float>(i_distances[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -109,7 +109,26 @@ TestProcess(std::shared_ptr<milvus::Connection> connection,
|
|||
TOP_K,
|
||||
NPROBE,
|
||||
search_entity_array,
|
||||
topk_query_result);
|
||||
topk_query_result,
|
||||
milvus::MetricType::HAMMING);
|
||||
|
||||
milvus_sdk::Utils::DoSearch(connection,
|
||||
collection_param.collection_name,
|
||||
partition_tags,
|
||||
TOP_K,
|
||||
NPROBE,
|
||||
search_entity_array,
|
||||
topk_query_result,
|
||||
milvus::MetricType::SUBSTRUCTURE);
|
||||
|
||||
milvus_sdk::Utils::DoSearch(connection,
|
||||
collection_param.collection_name,
|
||||
partition_tags,
|
||||
TOP_K,
|
||||
NPROBE,
|
||||
search_entity_array,
|
||||
topk_query_result,
|
||||
milvus::MetricType::SUPERSTRUCTURE);
|
||||
}
|
||||
|
||||
{ // wait unit build index finish
|
||||
|
@ -170,41 +189,5 @@ ClientTest::Test(const std::string& address, const std::string& port) {
|
|||
TestProcess(connection, collection_param, index_param);
|
||||
}
|
||||
|
||||
{
|
||||
milvus::CollectionParam collection_param = {
|
||||
"collection_2",
|
||||
512, // dimension
|
||||
512, // index file size
|
||||
milvus::MetricType::SUBSTRUCTURE
|
||||
};
|
||||
|
||||
JSON json_params = {};
|
||||
milvus::IndexParam index_param = {
|
||||
collection_param.collection_name,
|
||||
milvus::IndexType::FLAT,
|
||||
json_params.dump()
|
||||
};
|
||||
|
||||
TestProcess(connection, collection_param, index_param);
|
||||
}
|
||||
|
||||
{
|
||||
milvus::CollectionParam collection_param = {
|
||||
"collection_3",
|
||||
128, // dimension
|
||||
1024, // index file size
|
||||
milvus::MetricType::SUPERSTRUCTURE
|
||||
};
|
||||
|
||||
JSON json_params = {};
|
||||
milvus::IndexParam index_param = {
|
||||
collection_param.collection_name,
|
||||
milvus::IndexType::FLAT,
|
||||
json_params.dump()
|
||||
};
|
||||
|
||||
TestProcess(connection, collection_param, index_param);
|
||||
}
|
||||
|
||||
milvus::Connection::Destroy(connection);
|
||||
}
|
||||
|
|
|
@ -202,7 +202,7 @@ void
|
|||
Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& collection_name,
|
||||
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
|
||||
const std::vector<std::pair<int64_t, milvus::Entity>>& entity_array,
|
||||
milvus::TopKQueryResult& topk_query_result) {
|
||||
milvus::TopKQueryResult& topk_query_result, milvus::MetricType metric_type) {
|
||||
topk_query_result.clear();
|
||||
|
||||
std::vector<milvus::Entity> temp_entity_array;
|
||||
|
@ -213,6 +213,10 @@ Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& col
|
|||
{
|
||||
BLOCK_SPLITER
|
||||
JSON json_params = {{"nprobe", nprobe}};
|
||||
if (metric_type != milvus::MetricType::INVALID) {
|
||||
json_params["metric_type"] = metric_type;
|
||||
}
|
||||
|
||||
milvus_sdk::TimeRecorder rc("Search");
|
||||
milvus::Status stat =
|
||||
conn->Search(collection_name,
|
||||
|
|
|
@ -69,7 +69,8 @@ class Utils {
|
|||
DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& collection_name,
|
||||
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
|
||||
const std::vector<std::pair<int64_t, milvus::Entity>>& entity_array,
|
||||
milvus::TopKQueryResult& topk_query_result);
|
||||
milvus::TopKQueryResult& topk_query_result,
|
||||
milvus::MetricType metric_type = milvus::MetricType::INVALID);
|
||||
|
||||
static std::vector<milvus::LeafQueryPtr>
|
||||
GenLeafQuery();
|
||||
|
|
|
@ -42,6 +42,7 @@ enum class IndexType {
|
|||
};
|
||||
|
||||
enum class MetricType {
|
||||
INVALID = 0,
|
||||
L2 = 1, // Euclidean Distance
|
||||
IP = 2, // Cosine Similarity
|
||||
HAMMING = 3, // Hamming Distance
|
||||
|
|
Loading…
Reference in New Issue