Let uid be shared pointer to not copy really in 0.10.4 (#4064)

* Let uid be shared pointer to not copy really in 0.10.4

Signed-off-by: cqy <yaya645@126.com>
pull/4094/head
cqy123456 2020-10-23 19:30:13 +08:00 committed by GitHub
parent 8d5c1a87d5
commit 52dbf7ed3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 33 deletions

View File

@ -428,9 +428,10 @@ ExecutionEngineImpl::Load(bool to_cache) {
auto& deleted_docs = segment_ptr->deleted_docs_ptr_->GetDeletedDocs();
auto& vectors_uids = vectors->GetMutableUids();
auto count = vectors_uids.size();
index_->SetUids(vectors_uids);
LOG_ENGINE_DEBUG_ << "set uids " << index_->GetUids().size() << " for index " << location_;
std::shared_ptr<std::vector<int64_t>> vector_uids_ptr = std::make_shared<std::vector<int64_t>>();
vector_uids_ptr->swap(vectors_uids);
index_->SetUids(vector_uids_ptr);
LOG_ENGINE_DEBUG_ << "set uids " << vector_uids_ptr->size() << " for index " << location_;
auto& vectors_data = vectors->GetData();
@ -442,6 +443,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
attr_size_.insert(std::pair(attrs_it->first, attrs_it->second->GetNbytes()));
}
auto count = vector_uids_ptr->size();
vector_count_ = count;
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = nullptr;
@ -497,11 +499,10 @@ ExecutionEngineImpl::Load(bool to_cache) {
}
}
index_->SetBlacklist(concurrent_bitset_ptr);
std::vector<segment::doc_id_t> uids;
segment_reader_ptr->LoadUids(uids);
index_->SetUids(uids);
LOG_ENGINE_DEBUG_ << "set uids " << index_->GetUids().size() << " for index " << location_;
std::shared_ptr<std::vector<int64_t>> uids_ptr = std::make_shared<std::vector<int64_t>>();
segment_reader_ptr->LoadUids(*uids_ptr);
index_->SetUids(uids_ptr);
LOG_ENGINE_DEBUG_ << "set uids " << index_->GetUids()->size() << " for index " << location_;
LOG_ENGINE_DEBUG_ << "Finished loading index file from segment " << segment_dir;
}
@ -681,10 +682,7 @@ ExecutionEngineImpl::CopyToFpga() {
indexFpga->SetIndexSize(indexsize);
indexFpga->CopyIndexToFpga();
indexFpga->SetBlacklist(index_->GetBlacklist());
// do real copy now, may optimizer later
auto uids = index_->GetUids();
indexFpga->SetUids(uids);
indexFpga->SetUids(index_->GetUids());
index_ = indexFpga;
FpgaCache();
@ -722,8 +720,7 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
throw Exception(DB_ERROR, "Illegal index params");
}
LOG_ENGINE_DEBUG_ << "Index config: " << conf.dump();
std::vector<segment::doc_id_t> uids;
std::shared_ptr<std::vector<segment::doc_id_t>> uids;
faiss::ConcurrentBitsetPtr blacklist;
if (from_index) {
auto dataset =
@ -747,7 +744,7 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
}
#endif
to_index->SetUids(uids);
LOG_ENGINE_DEBUG_ << "Set " << to_index->GetUids().size() << "uids for " << location;
LOG_ENGINE_DEBUG_ << "Set " << to_index->GetUids()->size() << "uids for " << location;
if (blacklist != nullptr) {
to_index->SetBlacklist(blacklist);
LOG_ENGINE_DEBUG_ << "Set blacklist for index " << location;
@ -757,8 +754,8 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
}
void
MapAndCopyResult(const knowhere::DatasetPtr& dataset, const std::vector<milvus::segment::doc_id_t>& uids, int64_t nq,
int64_t k, float* distances, int64_t* labels) {
MapAndCopyResult(const knowhere::DatasetPtr& dataset, std::shared_ptr<std::vector<milvus::segment::doc_id_t>> uids,
int64_t nq, int64_t k, float* distances, int64_t* labels) {
int64_t* res_ids = dataset->Get<int64_t*>(knowhere::meta::IDS);
float* res_dist = dataset->Get<float*>(knowhere::meta::DISTANCE);
@ -769,7 +766,7 @@ MapAndCopyResult(const knowhere::DatasetPtr& dataset, const std::vector<milvus::
for (int64_t i = 0; i < num; ++i) {
int64_t offset = res_ids[i];
if (offset != -1) {
labels[i] = uids[offset];
labels[i] = (*uids)[offset];
} else {
labels[i] = -1;
}
@ -1164,7 +1161,7 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, const milvu
auto result = index_->Query(dataset, conf);
rc.RecordSection("query done");
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld] get %ld uids from index %s", "search", 0, index_->GetUids().size(),
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld] get %ld uids from index %s", "search", 0, index_->GetUids()->size(),
location_.c_str());
MapAndCopyResult(result, index_->GetUids(), n, k, distances, labels);
rc.RecordSection("map uids " + std::to_string(n * k));
@ -1205,7 +1202,7 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, const mil
auto result = index_->Query(dataset, conf);
rc.RecordSection("query done");
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld] get %ld uids from index %s", "search", 0, index_->GetUids().size(),
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld] get %ld uids from index %s", "search", 0, index_->GetUids()->size(),
location_.c_str());
MapAndCopyResult(result, index_->GetUids(), n, k, distances, labels);
rc.RecordSection("map uids " + std::to_string(n * k));

View File

@ -94,15 +94,14 @@ class VecIndex : public Index {
bitset_ = std::move(bitset_ptr);
}
const std::vector<IDType>&
std::shared_ptr<std::vector<IDType>>
GetUids() const {
return uids_;
}
void
SetUids(std::vector<IDType>& uids) {
uids_.clear();
uids_.swap(uids);
SetUids(std::shared_ptr<std::vector<IDType>> uids) {
uids_ = uids;
}
size_t
@ -113,7 +112,7 @@ class VecIndex : public Index {
size_t
UidsSize() {
return uids_.size() * sizeof(IDType);
return (uids_ == nullptr) ? 0 : (uids_->size() * sizeof(IDType));
}
virtual int64_t
@ -141,7 +140,7 @@ class VecIndex : public Index {
protected:
IndexType index_type_ = "";
IndexMode index_mode_ = IndexMode::MODE_CPU;
std::vector<IDType> uids_;
std::shared_ptr<std::vector<IDType>> uids_ = nullptr;
int64_t index_size_ = -1;
private:

View File

@ -26,9 +26,7 @@ namespace cloner {
void
CopyIndexData(const VecIndexPtr& dst_index, const VecIndexPtr& src_index) {
/* do real copy */
auto uids = src_index->GetUids();
dst_index->SetUids(uids);
dst_index->SetUids(src_index->GetUids());
dst_index->SetBlacklist(src_index->GetBlacklist());
dst_index->SetIndexSize(src_index->IndexSize());

View File

@ -9,6 +9,8 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#define protected public
#include <gtest/gtest.h>
#include <boost/filesystem.hpp>
#include <vector>
@ -35,17 +37,18 @@ CreateExecEngine(const milvus::json& json_params, milvus::engine::MetricType met
json_params);
std::vector<float> data;
std::vector<int64_t> ids;
std::shared_ptr<std::vector<int64_t>> ids = std::make_shared<std::vector<int64_t>>();
data.reserve(ROW_COUNT * DIMENSION);
ids.reserve(ROW_COUNT);
ids->reserve(ROW_COUNT);
for (int64_t i = 0; i < ROW_COUNT; i++) {
ids.push_back(i);
ids->push_back(i);
for (uint16_t k = 0; k < DIMENSION; k++) {
data.push_back(i * DIMENSION + k);
}
}
auto status = engine_ptr->AddWithIds((int64_t)ids.size(), data.data(), ids.data());
auto status = engine_ptr->AddWithIds((int64_t)ids->size(), data.data(), ids->data());
(std::static_pointer_cast<milvus::engine::ExecutionEngineImpl>(engine_ptr))->index_->SetUids(ids);
return engine_ptr;
}