Refine string parameters, avoid copying or dereferencing (#22708)

Signed-off-by: yah01 <yang.cen@zilliz.com>
pull/22730/head
yah01 2023-03-13 17:53:53 +08:00 committed by GitHub
parent 8f3d6e08df
commit a4031da634
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 102 additions and 79 deletions

View File

@ -29,8 +29,8 @@ SuccessCStatus() {
}
inline CStatus
FailureCStatus(ErrorCode error_code, const std::string& str) {
auto str_dup = strdup(str.c_str());
FailureCStatus(ErrorCode error_code, const std::string_view str) {
auto str_dup = strdup(str.data());
return CStatus{error_code, str_dup};
}

View File

@ -25,7 +25,7 @@ DatasetPtr
SortRangeSearchResult(DatasetPtr data_set,
int64_t topk,
int64_t nq,
std::string metric_type) {
const std::string_view metric_type) {
/**
* nq: number of queries;
* lims: the size of lims is nq + 1, lims[i+1] - lims[i] refers to the size of RangeSearch result queries[i]
@ -104,7 +104,7 @@ SortRangeSearchResult(DatasetPtr data_set,
void
CheckRangeSearchParam(float radius,
float range_filter,
std::string metric_type) {
const std::string_view metric_type) {
/*
* IP: 1.0 range_filter radius
* |------------+---------------| min_heap descending_order

View File

@ -20,10 +20,10 @@ DatasetPtr
SortRangeSearchResult(DatasetPtr data_set,
int64_t topk,
int64_t nq,
std::string metric_type);
const std::string_view metric_type);
void
CheckRangeSearchParam(float radius,
float range_filter,
std::string metric_type);
const std::string_view metric_type);
} // namespace milvus

View File

@ -16,27 +16,28 @@
#pragma once
#include <memory>
#include <limits>
#include <string>
#include <utility>
#include <vector>
#include <unordered_map>
#include <tbb/concurrent_unordered_map.h>
#include <tbb/concurrent_unordered_set.h>
#include <NamedType/named_type.hpp>
#include <boost/align/aligned_allocator.hpp>
#include <boost/container/vector.hpp>
#include <boost/dynamic_bitset.hpp>
#include <NamedType/named_type.hpp>
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
#include "nlohmann/json.hpp"
#include "knowhere/comp/index_param.h"
#include "knowhere/binaryset.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/dataset.h"
#include "nlohmann/json.hpp"
#include "pb/plan.pb.h"
#include "pb/schema.pb.h"
#include "pb/segcore.pb.h"
#include "pb/plan.pb.h"
namespace milvus {
@ -132,4 +133,23 @@ using IndexType = knowhere::IndexType;
// TODO :: type define milvus index mode, add transfer func from milvus index mode to knowhere index mode
using IndexMode = knowhere::IndexMode;
// Plus 1 because we can't use greater(>) symbol
constexpr size_t REF_SIZE_THRESHOLD = 16 + 1;
template <typename T>
using MayRef = std::conditional_t<!std::is_trivially_copyable_v<T> ||
sizeof(T) >= REF_SIZE_THRESHOLD,
T&,
T>;
template <typename T>
using Parameter = std::
conditional_t<std::is_same_v<T, std::string>, std::string_view, MayRef<T>>;
static_assert(std::is_same_v<int64_t, Parameter<int64_t>>);
static_assert(std::is_same_v<std::string_view, Parameter<std::string>>);
struct LargeType {
int64_t x, y, z;
};
static_assert(std::is_same_v<LargeType&, Parameter<LargeType>>);
} // namespace milvus

View File

@ -91,13 +91,13 @@ GenResultDataset(const int64_t nq,
}
inline bool
PostfixMatch(const std::string_view str, const std::string& postfix) {
PostfixMatch(const std::string_view str, const std::string_view postfix) {
if (postfix.length() > str.length()) {
return false;
}
int offset = str.length() - postfix.length();
auto ret = strncmp(str.data() + offset, postfix.c_str(), postfix.length());
auto ret = strncmp(str.data() + offset, postfix.data(), postfix.length());
if (ret != 0) {
return false;
}
@ -127,8 +127,9 @@ upper_div(int64_t value, int64_t align) {
}
inline bool
IsMetricType(const std::string& str, const knowhere::MetricType& metric_type) {
return !strcasecmp(str.c_str(), metric_type.c_str());
IsMetricType(const std::string_view str,
const knowhere::MetricType& metric_type) {
return !strcasecmp(str.data(), metric_type.c_str());
}
inline bool

View File

@ -21,7 +21,7 @@ namespace milvus::ChunkMangerConfig {
std::string LOCAL_ROOT_PATH = "/tmp/milvus"; // NOLINT
void
SetLocalRootPath(const std::string& path_prefix) {
SetLocalRootPath(const std::string_view path_prefix) {
LOCAL_ROOT_PATH = path_prefix;
}

View File

@ -21,7 +21,7 @@
namespace milvus::ChunkMangerConfig {
void
SetLocalRootPath(const std::string& path_prefix);
SetLocalRootPath(const std::string_view path_prefix);
std::string
GetLocalRootPath();

View File

@ -40,7 +40,7 @@ class StringIndex : public ScalarIndex<std::string> {
}
virtual const TargetBitmapPtr
PrefixMatch(std::string prefix) = 0;
PrefixMatch(const std::string_view prefix) = 0;
};
using StringIndexPtr = std::unique_ptr<StringIndex>;
} // namespace milvus::index

View File

@ -231,7 +231,7 @@ StringIndexMarisa::Range(std::string lower_bound_value,
}
const TargetBitmapPtr
StringIndexMarisa::PrefixMatch(std::string prefix) {
StringIndexMarisa::PrefixMatch(std::string_view prefix) {
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(str_ids_.size());
auto matched = prefix_match(prefix);
for (const auto str_id : matched) {
@ -266,9 +266,9 @@ StringIndexMarisa::fill_offsets() {
}
size_t
StringIndexMarisa::lookup(const std::string& str) {
StringIndexMarisa::lookup(const std::string_view str) {
marisa::Agent agent;
agent.set_query(str.c_str());
agent.set_query(str.data());
if (trie_.lookup(agent)) {
return agent.key().id();
}
@ -278,10 +278,10 @@ StringIndexMarisa::lookup(const std::string& str) {
}
std::vector<size_t>
StringIndexMarisa::prefix_match(const std::string& prefix) {
StringIndexMarisa::prefix_match(const std::string_view prefix) {
std::vector<size_t> ret;
marisa::Agent agent;
agent.set_query(prefix.c_str());
agent.set_query(prefix.data());
while (trie_.predictive_search(agent)) {
ret.push_back(agent.key().id());
}

View File

@ -64,7 +64,7 @@ class StringIndexMarisa : public StringIndex {
bool ub_inclusive) override;
const TargetBitmapPtr
PrefixMatch(std::string prefix) override;
PrefixMatch(const std::string_view prefix) override;
std::string
Reverse_Lookup(size_t offset) const override;
@ -78,10 +78,10 @@ class StringIndexMarisa : public StringIndex {
// Get str_id by str; returns -1 if str is not found.
size_t
lookup(const std::string& str);
lookup(const std::string_view str);
std::vector<size_t>
prefix_match(const std::string& prefix);
prefix_match(const std::string_view prefix);
private:
Config config_;

View File

@ -130,7 +130,7 @@ GetIndexModeFromConfig(const Config& config) {
}
IndexMode
GetIndexMode(const std::string index_mode) {
GetIndexMode(const std::string_view index_mode) {
if (index_mode.compare("CPU") == 0 || index_mode.compare("cpu") == 0) {
return IndexMode::MODE_CPU;
}

View File

@ -112,7 +112,7 @@ IndexMode
GetIndexModeFromConfig(const Config& config);
IndexMode
GetIndexMode(const std::string index_mode);
GetIndexMode(const std::string_view index_mode);
storage::FieldDataMeta
GetFieldDataMetaFromConfig(const Config& config);

View File

@ -51,12 +51,12 @@ LogOut(const char* pattern, ...) {
}
void
SetThreadName(const std::string& name) {
SetThreadName(const std::string_view name) {
// Note: the name cannot exceed 16 bytes
#ifdef __APPLE__
pthread_setname_np(name.c_str());
pthread_setname_np(name.data());
#elif defined(__linux__) || defined(__MINGW64__)
pthread_setname_np(pthread_self(), name.c_str());
pthread_setname_np(pthread_self(), name.data());
#else
#error "Unsupported SetThreadName";
#endif

View File

@ -129,7 +129,7 @@ std::string
LogOut(const char* pattern, ...);
void
SetThreadName(const std::string& name);
SetThreadName(const std::string_view name);
std::string
GetThreadName();

View File

@ -28,7 +28,7 @@ namespace milvus::query {
class Parser {
public:
friend std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl_str);
CreatePlan(const Schema& schema, const std::string_view dsl_str);
private:
std::unique_ptr<Plan>

View File

@ -24,10 +24,10 @@ namespace milvus::query {
// deprecated
std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan,
const std::string& placeholder_group_blob) {
const std::string_view placeholder_group_blob) {
return ParsePlaceholderGroup(
plan,
reinterpret_cast<const uint8_t*>(placeholder_group_blob.c_str()),
reinterpret_cast<const uint8_t*>(placeholder_group_blob.data()),
placeholder_group_blob.size());
}
@ -64,7 +64,7 @@ ParsePlaceholderGroup(const Plan* plan,
}
std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl_str) {
CreatePlan(const Schema& schema, const std::string_view dsl_str) {
Json dsl;
dsl = json::parse(dsl_str);
auto plan = Parser(schema).CreatePlanImpl(dsl);

View File

@ -27,7 +27,7 @@ struct PlaceholderGroup;
struct RetrievePlan;
std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl);
CreatePlan(const Schema& schema, const std::string_view dsl);
// Note: serialized_expr_plan is of binary format
std::unique_ptr<Plan>
@ -43,7 +43,7 @@ ParsePlaceholderGroup(const Plan* plan,
// deprecated
std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan,
const std::string& placeholder_group_blob);
const std::string_view placeholder_group_blob);
int64_t
GetNumOfQueries(const PlaceholderGroup*);

View File

@ -11,15 +11,15 @@
#pragma once
#include <functional>
#include <string>
#include "common/Utils.h"
#include "common/VectorTrait.h"
#include "exceptions/EasyAssert.h"
#include "query/Expr.h"
#include "common/Utils.h"
#include "query/Utils.h"
#include <functional>
#include <string>
namespace milvus::query {
template <typename Op, typename T, typename U>
bool

View File

@ -198,7 +198,7 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
Status
DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<
::milvus::grpc::VectorParam>& vector_params,
const std::string& dsl_string,
const std::string_view dsl_string,
query_old::BooleanQueryPtr& boolean_query,
query_old::QueryPtr& query_ptr) {
#if 1
@ -214,7 +214,7 @@ DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<
"DSL must include vector query");
}
for (const auto& vector_param : vector_params) {
const std::string& vector_string = vector_param.json();
const std::string_view vector_string = vector_param.json();
milvus::json vector_json = Json::parse(vector_string);
milvus::json::iterator it = vector_json.begin();
std::string placeholder = it.key();
@ -222,7 +222,7 @@ DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<
auto vector_query = std::make_shared<query_old::VectorQuery>();
milvus::json::iterator vector_param_it = it.value().begin();
if (vector_param_it != it.value().end()) {
const std::string& field_name = vector_param_it.key();
const std::string_view field_name = vector_param_it.key();
vector_query->field_name = field_name;
milvus::json param_json = vector_param_it.value();
int64_t topk = param_json["topk"];

View File

@ -26,7 +26,7 @@ namespace {
Status
CheckParameterRange(const milvus::json& json_params,
const std::string& param_name,
const std::string_view param_name,
int64_t min,
int64_t max,
bool min_close = true,
@ -63,7 +63,7 @@ CheckParameterRange(const milvus::json& json_params,
Status
CheckParameterExistence(const milvus::json& json_params,
const std::string& param_name) {
const std::string_view param_name) {
if (json_params.find(param_name) == json_params.end()) {
std::string msg = "Parameter list must contain: ";
msg += param_name;
@ -92,7 +92,7 @@ CheckParameterExistence(const milvus::json& json_params,
} // namespace
Status
ValidateCollectionName(const std::string& collection_name) {
ValidateCollectionName(const std::string_view collection_name) {
// Collection name shouldn't be empty.
if (collection_name.empty()) {
std::string msg = "Collection name should not be empty.";
@ -138,7 +138,7 @@ ValidateCollectionName(const std::string& collection_name) {
}
Status
ValidateFieldName(const std::string& field_name) {
ValidateFieldName(const std::string_view field_name) {
// Field name shouldn't be empty.
if (field_name.empty()) {
std::string msg = "Field name should not be empty.";
@ -269,7 +269,7 @@ ValidateDimension(int64_t dim, bool is_binary) {
Status
ValidateIndexParams(const milvus::json& index_params,
int64_t dimension,
const std::string& index_type) {
const std::string_view index_type) {
if (engine::utils::IsFlatIndexType(index_type)) {
return Status::OK();
} else if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT ||
@ -362,8 +362,8 @@ ValidateSegmentRowCount(int64_t segment_row_count) {
}
Status
ValidateIndexMetricType(const std::string& metric_type,
const std::string& index_type) {
ValidateIndexMetricType(const std::string_view metric_type,
const std::string_view index_type) {
if (engine::utils::IsFlatIndexType(index_type)) {
// pass
} else if (index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) {
@ -391,7 +391,7 @@ ValidateIndexMetricType(const std::string& metric_type,
}
Status
ValidateSearchMetricType(const std::string& metric_type, bool is_binary) {
ValidateSearchMetricType(const std::string_view metric_type, bool is_binary) {
if (is_binary) {
// binary
if (metric_type == knowhere::Metric::L2 ||
@ -432,7 +432,7 @@ ValidateSearchTopk(int64_t top_k) {
Status
ValidatePartitionTags(const std::vector<std::string>& partition_tags) {
for (const std::string& tag : partition_tags) {
for (const std::string_view tag : partition_tags) {
// Partition nametag shouldn't be empty.
if (tag.empty()) {
std::string msg = "Partition tag should not be empty.";

View File

@ -27,10 +27,10 @@ constexpr int64_t GPU_QUERY_MAX_TOPK = 2048;
constexpr int64_t GPU_QUERY_MAX_NPROBE = 2048;
extern Status
ValidateCollectionName(const std::string& collection_name);
ValidateCollectionName(const std::string_view collection_name);
extern Status
ValidateFieldName(const std::string& field_name);
ValidateFieldName(const std::string_view field_name);
extern Status
ValidateDimension(int64_t dimension, bool is_binary);
@ -44,17 +44,17 @@ ValidateStructuredIndexType(std::string& index_type);
extern Status
ValidateIndexParams(const milvus::json& index_params,
int64_t dimension,
const std::string& index_type);
const std::string_view index_type);
extern Status
ValidateSegmentRowCount(int64_t segment_row_count);
extern Status
ValidateIndexMetricType(const std::string& metric_type,
const std::string& index_type);
ValidateIndexMetricType(const std::string_view metric_type,
const std::string_view index_type);
extern Status
ValidateSearchMetricType(const std::string& metric_type, bool is_binary);
ValidateSearchMetricType(const std::string_view metric_type, bool is_binary);
extern Status
ValidateSearchTopk(int64_t top_k);

View File

@ -16,7 +16,7 @@
namespace milvus::segcore {
Collection::Collection(const std::string& collection_proto)
Collection::Collection(const std::string_view collection_proto)
: schema_proto_(collection_proto) {
parse();
}

View File

@ -20,7 +20,7 @@ namespace milvus::segcore {
class Collection {
public:
explicit Collection(const std::string& collection_proto);
explicit Collection(const std::string_view collection_proto);
void
parse();
@ -31,7 +31,7 @@ class Collection {
return schema_;
}
const std::string&
const std::string_view
get_collection_name() {
return collection_name_;
}

View File

@ -79,7 +79,7 @@ LocalChunkManager::Read(const std::string& filepath,
void* buf,
uint64_t size) {
std::ifstream infile;
infile.open(filepath, std::ios_base::binary);
infile.open(filepath.data(), std::ios_base::binary);
if (infile.fail()) {
std::stringstream err_msg;
err_msg << "Error: open local file '" << filepath << " failed, "
@ -104,7 +104,7 @@ LocalChunkManager::Write(const std::string& absPathStr,
void* buf,
uint64_t size) {
std::ofstream outfile;
outfile.open(absPathStr, std::ios_base::binary);
outfile.open(absPathStr.data(), std::ios_base::binary);
if (outfile.fail()) {
std::stringstream err_msg;
err_msg << "Error: open local file '" << absPathStr << " failed, "
@ -126,7 +126,7 @@ LocalChunkManager::Write(const std::string& absPathStr,
uint64_t size) {
std::ofstream outfile;
outfile.open(
absPathStr,
absPathStr.data(),
std::ios_base::in | std::ios_base::out | std::ios_base::binary);
if (outfile.fail()) {
std::stringstream err_msg;

View File

@ -14,20 +14,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include "storage/MinioChunkManager.h"
#include <aws/core/auth/AWSCredentials.h>
#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/core/auth/STSCredentialsProvider.h>
#include <aws/s3/model/CreateBucketRequest.h>
#include <aws/s3/model/DeleteBucketRequest.h>
#include <aws/s3/model/HeadBucketRequest.h>
#include <aws/s3/model/DeleteObjectRequest.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadBucketRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
#include <aws/s3/model/ListObjectsRequest.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/PutObjectRequest.h>
#include "storage/MinioChunkManager.h"
#include <fstream>
#include "exceptions/EasyAssert.h"
#include "log/Log.h"

View File

@ -39,11 +39,11 @@ ReleaseArrowUnused() {
}
static const char*
ErrorMsg(const std::string& msg) {
ErrorMsg(const std::string_view msg) {
if (msg.empty())
return nullptr;
auto ret = (char*)malloc(msg.size() + 1);
std::memcpy(ret, msg.c_str(), msg.size());
std::memcpy(ret, msg.data(), msg.size());
ret[msg.size()] = '\0';
return ret;
}

View File

@ -18,7 +18,7 @@ namespace milvus {
constexpr int CODE_WIDTH = sizeof(StatusCode);
Status::Status(StatusCode code, const std::string& msg) {
Status::Status(StatusCode code, const std::string_view msg) {
// 4 bytes store code
// 4 bytes store message length
// the left bytes store message string

View File

@ -30,7 +30,7 @@ using StatusCode = ErrorCode;
class Status {
public:
Status(StatusCode code, const std::string& msg);
Status(StatusCode code, const std::string_view msg);
Status() = default;
virtual ~Status();