mirror of https://github.com/milvus-io/milvus.git
Integrates marisa trie index (#16192)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/16326/head
parent
8b0f260b05
commit
fd589baca7
2
Makefile
2
Makefile
|
@ -210,7 +210,7 @@ docker: install
|
|||
install: all
|
||||
@echo "Installing binary to './bin'"
|
||||
@mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/milvus $(GOPATH)/bin/milvus
|
||||
@mkdir -p $(LIBRARY_PATH) && cp -P $(PWD)/internal/core/output/lib/* $(LIBRARY_PATH)
|
||||
@mkdir -p $(LIBRARY_PATH) && cp -r -P $(PWD)/internal/core/output/lib/* $(LIBRARY_PATH)
|
||||
@echo "Installation successful."
|
||||
|
||||
clean:
|
||||
|
|
|
@ -34,6 +34,7 @@ add_executable(indexbuilder_bench ${indexbuilder_bench_srcs})
|
|||
target_link_libraries(indexbuilder_bench
|
||||
milvus_segcore
|
||||
milvus_indexbuilder
|
||||
milvus_index
|
||||
log
|
||||
pthread
|
||||
knowhere
|
||||
|
|
|
@ -9,6 +9,8 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common/type_c.h"
|
||||
#include <string>
|
||||
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "index/ScalarIndexSort.h"
|
||||
|
||||
namespace milvus::scalar {
|
||||
|
||||
// TODO: optimize here.
|
||||
class BoolIndex : public ScalarIndexSort<bool> {
|
||||
public:
|
||||
void
|
||||
BuildWithDataset(const DatasetPtr& dataset) override {
|
||||
auto size = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto data = dataset->Get<const void*>(knowhere::meta::TENSOR);
|
||||
proto::schema::BoolArray arr;
|
||||
arr.ParseFromArray(data, size);
|
||||
Build(arr.data().size(), arr.data().data());
|
||||
}
|
||||
};
|
||||
using BoolIndexPtr = std::unique_ptr<BoolIndex>;
|
||||
|
||||
inline BoolIndexPtr
|
||||
CreateBoolIndex() {
|
||||
return std::make_unique<BoolIndex>();
|
||||
}
|
||||
} // namespace milvus::scalar
|
|
@ -11,9 +11,21 @@
|
|||
|
||||
aux_source_directory( ${MILVUS_ENGINE_SRC}/index INDEX_FILES )
|
||||
|
||||
add_library( milvus_index STATIC ${INDEX_FILES} )
|
||||
add_library( milvus_index SHARED ${INDEX_FILES} )
|
||||
|
||||
# TODO: support compile marisa on windows and mac.
|
||||
set(PLATFORM_LIBS )
|
||||
if ( LINUX )
|
||||
set(PLATFORM_LIBS marisa)
|
||||
endif()
|
||||
if (MSYS)
|
||||
set(PLATFORM_LIBS -Wl,--allow-multiple-definition)
|
||||
endif ()
|
||||
|
||||
target_link_libraries(milvus_index
|
||||
knowhere
|
||||
milvus_proto
|
||||
knowhere
|
||||
${PLATFORM_LIBS}
|
||||
)
|
||||
|
||||
install(TARGETS milvus_index DESTINATION lib)
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
#include <memory>
|
||||
#include <knowhere/index/Index.h>
|
||||
#include <knowhere/common/Dataset.h>
|
||||
#include <knowhere/index/structured_index_simple/StructuredIndex.h>
|
||||
#include "index/OperatorType.h"
|
||||
#include <boost/dynamic_bitset.hpp>
|
||||
|
||||
namespace milvus::scalar {
|
||||
using Index = knowhere::Index;
|
||||
|
@ -22,11 +23,16 @@ using IndexPtr = std::unique_ptr<Index>;
|
|||
using BinarySet = knowhere::BinarySet;
|
||||
using Config = knowhere::Config;
|
||||
using DatasetPtr = knowhere::DatasetPtr;
|
||||
using OperatorType = knowhere::scalar::OperatorType;
|
||||
using TargetBitmap = boost::dynamic_bitset<>;
|
||||
using TargetBitmapPtr = std::unique_ptr<TargetBitmap>;
|
||||
|
||||
class IndexBase : public Index {
|
||||
public:
|
||||
virtual void
|
||||
Build(const DatasetPtr& dataset) = 0;
|
||||
BuildWithDataset(const DatasetPtr& dataset) = 0;
|
||||
|
||||
virtual const TargetBitmapPtr
|
||||
Query(const DatasetPtr& dataset) = 0;
|
||||
};
|
||||
using IndexBasePtr = std::unique_ptr<IndexBase>;
|
||||
|
||||
|
|
|
@ -10,14 +10,32 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <string>
|
||||
#include "ScalarIndexSort.h"
|
||||
#include "index/ScalarIndexSort.h"
|
||||
#include "index/StringIndexMarisa.h"
|
||||
#include "index/IndexType.h"
|
||||
#include "index/BoolIndex.h"
|
||||
|
||||
namespace milvus::scalar {
|
||||
|
||||
template <typename T>
|
||||
inline ScalarIndexPtr<T>
|
||||
IndexFactory::CreateIndex(std::string index_type) {
|
||||
IndexFactory::CreateIndex(const std::string& index_type) {
|
||||
return CreateScalarIndexSort<T>();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline ScalarIndexPtr<bool>
|
||||
IndexFactory::CreateIndex(const std::string& index_type) {
|
||||
return CreateBoolIndex();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline ScalarIndexPtr<std::string>
|
||||
IndexFactory::CreateIndex(const std::string& index_type) {
|
||||
#ifdef __linux__
|
||||
return CreateStringIndexMarisa();
|
||||
#endif
|
||||
throw std::runtime_error("unsupported platform");
|
||||
}
|
||||
|
||||
} // namespace milvus::scalar
|
||||
|
|
|
@ -11,11 +11,12 @@
|
|||
|
||||
#include "index/IndexFactory.h"
|
||||
#include "index/ScalarIndexSort.h"
|
||||
#include "index/StringIndexMarisa.h"
|
||||
|
||||
namespace milvus::scalar {
|
||||
|
||||
IndexBasePtr
|
||||
IndexFactory::CreateIndex(CDataType dtype, std::string index_type) {
|
||||
IndexFactory::CreateIndex(CDataType dtype, const std::string& index_type) {
|
||||
switch (dtype) {
|
||||
case Bool:
|
||||
return CreateIndex<bool>(index_type);
|
||||
|
@ -31,9 +32,11 @@ IndexFactory::CreateIndex(CDataType dtype, std::string index_type) {
|
|||
return CreateIndex<float>(index_type);
|
||||
case Double:
|
||||
return CreateIndex<double>(index_type);
|
||||
|
||||
case String:
|
||||
case VarChar:
|
||||
return CreateIndex<std::string>(index_type);
|
||||
|
||||
case None:
|
||||
case BinaryVector:
|
||||
case FloatVector:
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
#include <utils/Types.h>
|
||||
#include "index/Index.h"
|
||||
#include "common/type_c.h"
|
||||
#include "ScalarIndex.h"
|
||||
#include "index/ScalarIndex.h"
|
||||
#include "index/StringIndex.h"
|
||||
#include <string>
|
||||
|
||||
namespace milvus::scalar {
|
||||
|
@ -35,11 +36,11 @@ class IndexFactory {
|
|||
}
|
||||
|
||||
IndexBasePtr
|
||||
CreateIndex(CDataType dtype, std::string index_type);
|
||||
CreateIndex(CDataType dtype, const std::string& index_type);
|
||||
|
||||
template <typename T>
|
||||
ScalarIndexPtr<T>
|
||||
CreateIndex(std::string index_type);
|
||||
CreateIndex(const std::string& index_type);
|
||||
};
|
||||
|
||||
} // namespace milvus::scalar
|
||||
|
|
|
@ -9,10 +9,8 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "index/Index.h"
|
||||
#pragma once
|
||||
|
||||
namespace milvus::scalar {
|
||||
void
|
||||
dummy() {
|
||||
constexpr const char* INDEX_TYPE_MARISA = "marisa";
|
||||
}
|
||||
} // namespace milvus::scalar
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace milvus::scalar {
|
||||
constexpr const char* OPERATOR_TYPE = "operator_type";
|
||||
constexpr const char* RANGE_VALUE = "range_value";
|
||||
constexpr const char* LOWER_BOUND_VALUE = "lower_bound_value";
|
||||
constexpr const char* LOWER_BOUND_INCLUSIVE = "lower_bound_inclusive";
|
||||
constexpr const char* UPPER_BOUND_VALUE = "upper_bound_value";
|
||||
constexpr const char* UPPER_BOUND_INCLUSIVE = "upper_bound_inclusive";
|
||||
constexpr const char* PREFIX_VALUE = "prefix_value";
|
||||
constexpr const char* MARISA_TRIE = "marisa_trie";
|
||||
// below configurations will be persistent, do not edit them.
|
||||
constexpr const char* MARISA_TRIE_INDEX = "marisa_trie_index";
|
||||
constexpr const char* MARISA_STR_IDS = "marisa_trie_str_ids";
|
||||
constexpr const char* FLAT_STR_INDEX = "flat_str_index";
|
||||
} // namespace milvus::scalar
|
|
@ -11,18 +11,16 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <knowhere/index/structured_index_simple/StructuredIndexSort.h>
|
||||
#include <string>
|
||||
|
||||
namespace milvus::indexbuilder {
|
||||
|
||||
class StringIndexImpl : public knowhere::scalar::StructuredIndexSort<std::string> {
|
||||
public:
|
||||
knowhere::BinarySet
|
||||
Serialize(const knowhere::Config& config) override;
|
||||
|
||||
void
|
||||
Load(const knowhere::BinarySet& index_binary) override;
|
||||
namespace milvus::scalar {
|
||||
enum OperatorType {
|
||||
LT = 0,
|
||||
LE = 1,
|
||||
GT = 3,
|
||||
GE = 4,
|
||||
RangeOp = 5,
|
||||
InOp = 6,
|
||||
NotInOp = 7,
|
||||
PrefixMatchOp = 8,
|
||||
PostfixMatchOp = 9,
|
||||
};
|
||||
|
||||
} // namespace milvus::indexbuilder
|
||||
} // namespace milvus::scalar
|
|
@ -0,0 +1,56 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "index/Meta.h"
|
||||
|
||||
namespace milvus::scalar {
|
||||
template <typename T>
|
||||
const TargetBitmapPtr
|
||||
ScalarIndex<T>::Query(const DatasetPtr& dataset) {
|
||||
auto op = dataset->Get<OperatorType>(OPERATOR_TYPE);
|
||||
switch (op) {
|
||||
case LT:
|
||||
case LE:
|
||||
case GT:
|
||||
case GE: {
|
||||
auto value = dataset->Get<T>(RANGE_VALUE);
|
||||
return Range(value, op);
|
||||
}
|
||||
|
||||
case RangeOp: {
|
||||
auto lower_bound_value = dataset->Get<T>(LOWER_BOUND_VALUE);
|
||||
auto upper_bound_value = dataset->Get<T>(UPPER_BOUND_VALUE);
|
||||
auto lower_bound_inclusive = dataset->Get<bool>(LOWER_BOUND_INCLUSIVE);
|
||||
auto upper_bound_inclusive = dataset->Get<bool>(UPPER_BOUND_INCLUSIVE);
|
||||
return Range(lower_bound_value, lower_bound_inclusive, upper_bound_value, upper_bound_inclusive);
|
||||
}
|
||||
|
||||
case InOp: {
|
||||
auto n = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto values = dataset->Get<const void*>(knowhere::meta::TENSOR);
|
||||
return In(n, reinterpret_cast<const T*>(values));
|
||||
}
|
||||
|
||||
case NotInOp: {
|
||||
auto n = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto values = dataset->Get<const void*>(knowhere::meta::TENSOR);
|
||||
return NotIn(n, reinterpret_cast<const T*>(values));
|
||||
}
|
||||
|
||||
case PrefixMatchOp:
|
||||
case PostfixMatchOp:
|
||||
default:
|
||||
throw std::invalid_argument(std::string("unsupported operator type: " + std::to_string(op)));
|
||||
}
|
||||
}
|
||||
} // namespace milvus::scalar
|
|
@ -19,9 +19,6 @@
|
|||
|
||||
namespace milvus::scalar {
|
||||
|
||||
using TargetBitmap = boost::dynamic_bitset<>;
|
||||
using TargetBitmapPtr = std::unique_ptr<TargetBitmap>;
|
||||
|
||||
template <typename T>
|
||||
class ScalarIndex : public IndexBase {
|
||||
public:
|
||||
|
@ -39,9 +36,14 @@ class ScalarIndex : public IndexBase {
|
|||
|
||||
virtual const TargetBitmapPtr
|
||||
Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) = 0;
|
||||
|
||||
const TargetBitmapPtr
|
||||
Query(const DatasetPtr& dataset) override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using ScalarIndexPtr = std::unique_ptr<ScalarIndex<T>>;
|
||||
|
||||
} // namespace milvus::scalar
|
||||
|
||||
#include "index/ScalarIndex-inl.h"
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "Meta.h"
|
||||
|
||||
namespace milvus::scalar {
|
||||
|
||||
|
@ -30,7 +31,7 @@ inline ScalarIndexSort<T>::ScalarIndexSort(const size_t n, const T* values) : is
|
|||
|
||||
template <typename T>
|
||||
inline void
|
||||
ScalarIndexSort<T>::Build(const DatasetPtr& dataset) {
|
||||
ScalarIndexSort<T>::BuildWithDataset(const DatasetPtr& dataset) {
|
||||
auto size = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto data = dataset->Get<const void*>(knowhere::meta::TENSOR);
|
||||
Build(size, reinterpret_cast<const T*>(data));
|
||||
|
@ -197,44 +198,4 @@ ScalarIndexSort<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_
|
|||
return bitset;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndexSort<std::string>::Build(const milvus::scalar::DatasetPtr& dataset) {
|
||||
auto size = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto data = dataset->Get<const void*>(knowhere::meta::TENSOR);
|
||||
proto::schema::StringArray arr;
|
||||
arr.ParseFromArray(data, size);
|
||||
// TODO: optimize here. avoid memory copy.
|
||||
std::vector<std::string> vecs{arr.data().begin(), arr.data().end()};
|
||||
Build(arr.data().size(), vecs.data());
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BinarySet
|
||||
ScalarIndexSort<std::string>::Serialize(const Config& config) {
|
||||
BinarySet res_set;
|
||||
auto data = this->GetData();
|
||||
for (const auto& record : data) {
|
||||
auto idx = record.idx_;
|
||||
auto str = record.a_;
|
||||
std::shared_ptr<uint8_t[]> content(new uint8_t[str.length()]);
|
||||
memcpy(content.get(), str.c_str(), str.length());
|
||||
res_set.Append(std::to_string(idx), content, str.length());
|
||||
}
|
||||
return res_set;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndexSort<std::string>::Load(const BinarySet& index_binary) {
|
||||
std::vector<std::string> vecs;
|
||||
|
||||
for (const auto& [k, v] : index_binary.binary_map_) {
|
||||
std::string str(reinterpret_cast<const char*>(v->data.get()), v->size);
|
||||
vecs.emplace_back(str);
|
||||
}
|
||||
|
||||
Build(vecs.size(), vecs.data());
|
||||
}
|
||||
|
||||
} // namespace milvus::scalar
|
||||
|
|
|
@ -24,8 +24,6 @@ namespace milvus::scalar {
|
|||
|
||||
template <typename T>
|
||||
class ScalarIndexSort : public ScalarIndex<T> {
|
||||
static_assert(std::is_fundamental_v<T> || std::is_same_v<T, std::string>);
|
||||
|
||||
public:
|
||||
ScalarIndexSort();
|
||||
ScalarIndexSort(size_t n, const T* values);
|
||||
|
@ -37,7 +35,7 @@ class ScalarIndexSort : public ScalarIndex<T> {
|
|||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
void
|
||||
Build(const DatasetPtr& dataset) override;
|
||||
BuildWithDataset(const DatasetPtr& dataset) override;
|
||||
|
||||
void
|
||||
Build(size_t n, const T* values) override;
|
||||
|
@ -57,6 +55,7 @@ class ScalarIndexSort : public ScalarIndex<T> {
|
|||
const TargetBitmapPtr
|
||||
Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) override;
|
||||
|
||||
public:
|
||||
const std::vector<IndexStructure<T>>&
|
||||
GetData() {
|
||||
return data_;
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "index/ScalarIndex.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "index/Meta.h"
|
||||
#include <pb/schema.pb.h>
|
||||
|
||||
namespace milvus::scalar {
|
||||
|
||||
class StringIndex : public ScalarIndex<std::string> {
|
||||
public:
|
||||
void
|
||||
BuildWithDataset(const DatasetPtr& dataset) override {
|
||||
auto size = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto data = dataset->Get<const void*>(knowhere::meta::TENSOR);
|
||||
proto::schema::StringArray arr;
|
||||
arr.ParseFromArray(data, size);
|
||||
|
||||
{
|
||||
// TODO: optimize here. avoid memory copy.
|
||||
std::vector<std::string> vecs{arr.data().begin(), arr.data().end()};
|
||||
Build(arr.data().size(), vecs.data());
|
||||
}
|
||||
|
||||
{
|
||||
// TODO: test this way.
|
||||
// auto strs = (const std::string*)arr.data().data();
|
||||
// Build(arr.data().size(), strs);
|
||||
}
|
||||
}
|
||||
|
||||
const TargetBitmapPtr
|
||||
Query(const DatasetPtr& dataset) override {
|
||||
auto op = dataset->Get<OperatorType>(OPERATOR_TYPE);
|
||||
if (op == PrefixMatchOp) {
|
||||
auto prefix = dataset->Get<std::string>(PREFIX_VALUE);
|
||||
return PrefixMatch(prefix);
|
||||
}
|
||||
return ScalarIndex<std::string>::Query(dataset);
|
||||
}
|
||||
|
||||
virtual const TargetBitmapPtr
|
||||
PrefixMatch(std::string prefix) = 0;
|
||||
};
|
||||
using StringIndexPtr = std::unique_ptr<StringIndex>;
|
||||
} // namespace milvus::scalar
|
|
@ -0,0 +1,220 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "index/StringIndexMarisa.h"
|
||||
#include "index/Utils.h"
|
||||
#include "index/Index.h"
|
||||
|
||||
#include <boost/uuid/uuid.hpp>
|
||||
#include <boost/uuid/uuid_io.hpp>
|
||||
#include <boost/uuid/uuid_generators.hpp>
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <fcntl.h>
|
||||
#include <knowhere/common/Utils.h>
|
||||
#include <pb/schema.pb.h>
|
||||
|
||||
namespace milvus::scalar {
|
||||
|
||||
#ifdef __linux__
|
||||
|
||||
int64_t
|
||||
StringIndexMarisa::Size() {
|
||||
return trie_.size();
|
||||
}
|
||||
|
||||
void
|
||||
StringIndexMarisa::Build(size_t n, const std::string* values) {
|
||||
if (built_) {
|
||||
throw std::runtime_error("index has been built");
|
||||
}
|
||||
|
||||
marisa::Keyset keyset;
|
||||
{
|
||||
// fill key set.
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
keyset.push_back(values[i].c_str());
|
||||
}
|
||||
}
|
||||
|
||||
trie_.build(keyset);
|
||||
fill_str_ids(n, values);
|
||||
fill_offsets();
|
||||
|
||||
built_ = true;
|
||||
}
|
||||
|
||||
BinarySet
|
||||
StringIndexMarisa::Serialize(const Config& config) {
|
||||
auto uuid = boost::uuids::random_generator()();
|
||||
auto uuid_string = boost::uuids::to_string(uuid);
|
||||
auto file = std::string("/tmp/") + uuid_string;
|
||||
|
||||
auto fd = open(file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR);
|
||||
trie_.write(fd);
|
||||
|
||||
auto size = get_file_size(fd);
|
||||
auto buf = new uint8_t[size];
|
||||
|
||||
while (read(fd, buf, size) != size) {
|
||||
lseek(fd, 0, SEEK_SET);
|
||||
}
|
||||
std::shared_ptr<uint8_t[]> index_data(buf);
|
||||
|
||||
close(fd);
|
||||
remove(file.c_str());
|
||||
|
||||
auto str_ids_len = str_ids_.size() * sizeof(size_t);
|
||||
std::shared_ptr<uint8_t[]> str_ids(new uint8_t[str_ids_len]);
|
||||
memcpy(str_ids.get(), str_ids_.data(), str_ids_len);
|
||||
|
||||
BinarySet res_set;
|
||||
res_set.Append(MARISA_TRIE_INDEX, index_data, size);
|
||||
res_set.Append(MARISA_STR_IDS, str_ids, str_ids_len);
|
||||
|
||||
knowhere::Disassemble(4 * 1024 * 1024, res_set);
|
||||
|
||||
return res_set;
|
||||
}
|
||||
|
||||
void
|
||||
StringIndexMarisa::Load(const BinarySet& set) {
|
||||
knowhere::Assemble(const_cast<BinarySet&>(set));
|
||||
|
||||
auto uuid = boost::uuids::random_generator()();
|
||||
auto uuid_string = boost::uuids::to_string(uuid);
|
||||
auto file = std::string("/tmp/") + uuid_string;
|
||||
|
||||
auto index = set.GetByName(MARISA_TRIE_INDEX);
|
||||
auto len = index->size;
|
||||
|
||||
auto fd = open(file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR);
|
||||
lseek(fd, 0, SEEK_SET);
|
||||
while (write(fd, index->data.get(), len) != len) {
|
||||
lseek(fd, 0, SEEK_SET);
|
||||
}
|
||||
|
||||
lseek(fd, 0, SEEK_SET);
|
||||
trie_.read(fd);
|
||||
close(fd);
|
||||
remove(file.c_str());
|
||||
|
||||
auto str_ids = set.GetByName(MARISA_STR_IDS);
|
||||
auto str_ids_len = str_ids->size;
|
||||
str_ids_.resize(str_ids_len / sizeof(size_t));
|
||||
memcpy(str_ids_.data(), str_ids->data.get(), str_ids_len);
|
||||
|
||||
fill_offsets();
|
||||
}
|
||||
|
||||
const TargetBitmapPtr
|
||||
StringIndexMarisa::In(size_t n, const std::string* values) {
|
||||
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(str_ids_.size());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
auto str = values[i];
|
||||
auto str_id = lookup(str);
|
||||
if (str_id >= 0) {
|
||||
auto offsets = str_ids_to_offsets_[str_id];
|
||||
for (auto offset : offsets) {
|
||||
bitset->set(offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
return bitset;
|
||||
}
|
||||
|
||||
const TargetBitmapPtr
|
||||
StringIndexMarisa::NotIn(size_t n, const std::string* values) {
|
||||
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(str_ids_.size());
|
||||
bitset->set();
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
auto str = values[i];
|
||||
auto str_id = lookup(str);
|
||||
if (str_id >= 0) {
|
||||
auto offsets = str_ids_to_offsets_[str_id];
|
||||
for (auto offset : offsets) {
|
||||
bitset->reset(offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
return bitset;
|
||||
}
|
||||
|
||||
const TargetBitmapPtr
|
||||
StringIndexMarisa::Range(std::string value, OperatorType op) {
|
||||
throw std::runtime_error("todo: unsupported now");
|
||||
}
|
||||
|
||||
const TargetBitmapPtr
|
||||
StringIndexMarisa::Range(std::string lower_bound_value,
|
||||
bool lb_inclusive,
|
||||
std::string upper_bound_value,
|
||||
bool ub_inclusive) {
|
||||
throw std::runtime_error("todo: unsupported now");
|
||||
}
|
||||
|
||||
const TargetBitmapPtr
|
||||
StringIndexMarisa::PrefixMatch(std::string prefix) {
|
||||
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(str_ids_.size());
|
||||
auto matched = prefix_match(prefix);
|
||||
for (const auto str_id : matched) {
|
||||
auto offsets = str_ids_to_offsets_[str_id];
|
||||
for (auto offset : offsets) {
|
||||
bitset->set(offset);
|
||||
}
|
||||
}
|
||||
return bitset;
|
||||
}
|
||||
|
||||
void
|
||||
StringIndexMarisa::fill_str_ids(size_t n, const std::string* values) {
|
||||
str_ids_.resize(n);
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
auto str = values[i];
|
||||
auto str_id = lookup(str);
|
||||
assert(str_id >= 0);
|
||||
str_ids_[i] = str_id;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
StringIndexMarisa::fill_offsets() {
|
||||
for (size_t offset = 0; offset < str_ids_.size(); offset++) {
|
||||
auto str_id = str_ids_[offset];
|
||||
if (str_ids_to_offsets_.find(str_id) == str_ids_to_offsets_.end()) {
|
||||
str_ids_to_offsets_[str_id] = std::vector<size_t>{};
|
||||
}
|
||||
str_ids_to_offsets_[str_id].push_back(offset);
|
||||
}
|
||||
}
|
||||
|
||||
size_t
|
||||
StringIndexMarisa::lookup(const std::string& str) {
|
||||
marisa::Agent agent;
|
||||
agent.set_query(str.c_str());
|
||||
trie_.lookup(agent);
|
||||
return agent.key().id();
|
||||
}
|
||||
|
||||
std::vector<size_t>
|
||||
StringIndexMarisa::prefix_match(const std::string& prefix) {
|
||||
std::vector<size_t> ret;
|
||||
marisa::Agent agent;
|
||||
agent.set_query(prefix.c_str());
|
||||
while (trie_.predictive_search(agent)) {
|
||||
ret.push_back(agent.key().id());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace milvus::scalar
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef __linux__
|
||||
|
||||
#include <marisa.h>
|
||||
#include "index/StringIndex.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
namespace milvus::scalar {
|
||||
|
||||
class StringIndexMarisa : public StringIndex {
|
||||
public:
|
||||
StringIndexMarisa() = default;
|
||||
|
||||
int64_t
|
||||
Size() override;
|
||||
|
||||
BinarySet
|
||||
Serialize(const Config& config) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& set) override;
|
||||
|
||||
void
|
||||
Build(size_t n, const std::string* values) override;
|
||||
|
||||
const TargetBitmapPtr
|
||||
In(size_t n, const std::string* values) override;
|
||||
|
||||
const TargetBitmapPtr
|
||||
NotIn(size_t n, const std::string* values) override;
|
||||
|
||||
const TargetBitmapPtr
|
||||
Range(std::string value, OperatorType op) override;
|
||||
|
||||
const TargetBitmapPtr
|
||||
Range(std::string lower_bound_value, bool lb_inclusive, std::string upper_bound_value, bool ub_inclusive) override;
|
||||
|
||||
const TargetBitmapPtr
|
||||
PrefixMatch(std::string prefix) override;
|
||||
|
||||
private:
|
||||
void
|
||||
fill_str_ids(size_t n, const std::string* values);
|
||||
|
||||
void
|
||||
fill_offsets();
|
||||
|
||||
// get str_id by str, if str not found, -1 was returned.
|
||||
size_t
|
||||
lookup(const std::string& str);
|
||||
|
||||
std::vector<size_t>
|
||||
prefix_match(const std::string& prefix);
|
||||
|
||||
private:
|
||||
marisa::Trie trie_;
|
||||
std::vector<size_t> str_ids_; // used to retrieve.
|
||||
std::map<size_t, std::vector<size_t>> str_ids_to_offsets_;
|
||||
bool built_ = false;
|
||||
};
|
||||
|
||||
using StringIndexMarisaPtr = std::unique_ptr<StringIndexMarisa>;
|
||||
|
||||
inline StringIndexPtr
|
||||
CreateStringIndexMarisa() {
|
||||
return std::make_unique<StringIndexMarisa>();
|
||||
}
|
||||
|
||||
} // namespace milvus::scalar
|
||||
|
||||
#endif
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <vector>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <iostream>
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
|
||||
namespace milvus::scalar {
|
||||
|
||||
size_t
|
||||
get_file_size(int fd) {
|
||||
struct stat s;
|
||||
fstat(fd, &s);
|
||||
return s.st_size;
|
||||
}
|
||||
|
||||
} // namespace milvus::scalar
|
|
@ -15,7 +15,7 @@ set(INDEXBUILDER_FILES
|
|||
index_c.cpp
|
||||
init_c.cpp
|
||||
utils.cpp
|
||||
StringIndexImpl.cpp
|
||||
ScalarIndexCreator.cpp
|
||||
)
|
||||
add_library(milvus_indexbuilder SHARED
|
||||
${INDEXBUILDER_FILES}
|
||||
|
@ -29,6 +29,7 @@ endif ()
|
|||
|
||||
# link order matters
|
||||
target_link_libraries(milvus_indexbuilder
|
||||
milvus_index
|
||||
milvus_common
|
||||
knowhere
|
||||
${TBB}
|
||||
|
|
|
@ -27,6 +27,7 @@ class IndexCreatorBase {
|
|||
virtual knowhere::BinarySet
|
||||
Serialize() = 0;
|
||||
|
||||
// used for test.
|
||||
virtual void
|
||||
Load(const knowhere::BinarySet&) = 0;
|
||||
|
||||
|
|
|
@ -44,32 +44,24 @@ class IndexFactory {
|
|||
auto invalid_dtype_msg = std::string("invalid data type: ") + std::to_string(real_dtype);
|
||||
|
||||
switch (real_dtype) {
|
||||
case milvus::proto::schema::Bool:
|
||||
return std::make_unique<ScalarIndexCreator<bool>>(type_params, index_params);
|
||||
case milvus::proto::schema::Int8:
|
||||
return std::make_unique<ScalarIndexCreator<int8_t>>(type_params, index_params);
|
||||
case milvus::proto::schema::Int16:
|
||||
return std::make_unique<ScalarIndexCreator<int16_t>>(type_params, index_params);
|
||||
case milvus::proto::schema::Int32:
|
||||
return std::make_unique<ScalarIndexCreator<int32_t>>(type_params, index_params);
|
||||
case milvus::proto::schema::Int64:
|
||||
return std::make_unique<ScalarIndexCreator<int64_t>>(type_params, index_params);
|
||||
case milvus::proto::schema::Float:
|
||||
return std::make_unique<ScalarIndexCreator<float_t>>(type_params, index_params);
|
||||
case milvus::proto::schema::Double:
|
||||
return std::make_unique<ScalarIndexCreator<double_t>>(type_params, index_params);
|
||||
|
||||
case proto::schema::Bool:
|
||||
case proto::schema::Int8:
|
||||
case proto::schema::Int16:
|
||||
case proto::schema::Int32:
|
||||
case proto::schema::Int64:
|
||||
case proto::schema::Float:
|
||||
case proto::schema::Double:
|
||||
case proto::schema::VarChar:
|
||||
case milvus::proto::schema::String:
|
||||
return std::make_unique<ScalarIndexCreator<std::string>>(type_params, index_params);
|
||||
case proto::schema::String:
|
||||
return CreateScalarIndex(dtype, type_params, index_params);
|
||||
|
||||
case milvus::proto::schema::BinaryVector:
|
||||
case milvus::proto::schema::FloatVector:
|
||||
case proto::schema::BinaryVector:
|
||||
case proto::schema::FloatVector:
|
||||
return std::make_unique<VecIndexCreator>(type_params, index_params);
|
||||
|
||||
case milvus::proto::schema::None:
|
||||
case milvus::proto::schema::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
|
||||
case milvus::proto::schema::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
|
||||
case proto::schema::None:
|
||||
case proto::schema::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
|
||||
case proto::schema::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
|
||||
default:
|
||||
throw std::invalid_argument(invalid_dtype_msg);
|
||||
}
|
||||
|
|
|
@ -1,84 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <knowhere/index/structured_index_simple/StructuredIndexSort.h>
|
||||
#include <pb/schema.pb.h>
|
||||
#include "indexbuilder/helper.h"
|
||||
#include "indexbuilder/StringIndexImpl.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace milvus::indexbuilder {
|
||||
|
||||
template <typename T>
|
||||
inline ScalarIndexCreator<T>::ScalarIndexCreator(const char* type_params, const char* index_params) {
|
||||
// TODO: move parse-related logic to a common interface.
|
||||
Helper::ParseFromString(type_params_, std::string(type_params));
|
||||
Helper::ParseFromString(index_params_, std::string(index_params));
|
||||
// TODO: create index according to the params.
|
||||
index_ = std::make_unique<knowhere::scalar::StructuredIndexSort<T>>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void
|
||||
ScalarIndexCreator<T>::Build(const knowhere::DatasetPtr& dataset) {
|
||||
auto size = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto data = dataset->Get<const void*>(knowhere::meta::TENSOR);
|
||||
index_->Build(size, reinterpret_cast<const T*>(data));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline knowhere::BinarySet
|
||||
ScalarIndexCreator<T>::Serialize() {
|
||||
return index_->Serialize(config_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void
|
||||
ScalarIndexCreator<T>::Load(const knowhere::BinarySet& binary_set) {
|
||||
index_->Load(binary_set);
|
||||
}
|
||||
|
||||
// not sure that the pointer of a golang bool array acts like other types.
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndexCreator<bool>::Build(const knowhere::DatasetPtr& dataset) {
|
||||
auto size = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto data = dataset->Get<const void*>(knowhere::meta::TENSOR);
|
||||
proto::schema::BoolArray arr;
|
||||
Helper::ParseParams(arr, data, size);
|
||||
index_->Build(arr.data().size(), arr.data().data());
|
||||
}
|
||||
|
||||
template <>
|
||||
inline ScalarIndexCreator<std::string>::ScalarIndexCreator(const char* type_params, const char* index_params) {
|
||||
// TODO: move parse-related logic to a common interface.
|
||||
Helper::ParseFromString(type_params_, std::string(type_params));
|
||||
Helper::ParseFromString(index_params_, std::string(index_params));
|
||||
// TODO: create index according to the params.
|
||||
index_ = std::make_unique<StringIndexImpl>();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndexCreator<std::string>::Build(const knowhere::DatasetPtr& dataset) {
|
||||
auto size = dataset->Get<int64_t>(knowhere::meta::ROWS);
|
||||
auto data = dataset->Get<const void*>(knowhere::meta::TENSOR);
|
||||
proto::schema::StringArray arr;
|
||||
Helper::ParseParams(arr, data, size);
|
||||
// TODO: optimize here. avoid memory copy.
|
||||
std::vector<std::string> vecs{arr.data().begin(), arr.data().end()};
|
||||
index_->Build(arr.data().size(), vecs.data());
|
||||
}
|
||||
|
||||
} // namespace milvus::indexbuilder
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "indexbuilder/helper.h"
|
||||
#include "indexbuilder/ScalarIndexCreator.h"
|
||||
#include "index/IndexFactory.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace milvus::indexbuilder {
|
||||
|
||||
ScalarIndexCreator::ScalarIndexCreator(CDataType dtype, const char* type_params, const char* index_params) {
|
||||
dtype_ = dtype;
|
||||
// TODO: move parse-related logic to a common interface.
|
||||
Helper::ParseFromString(type_params_, std::string(type_params));
|
||||
Helper::ParseFromString(index_params_, std::string(index_params));
|
||||
// TODO: create index according to the params.
|
||||
index_ = scalar::IndexFactory::GetInstance().CreateIndex(dtype_, index_type());
|
||||
}
|
||||
|
||||
void
|
||||
ScalarIndexCreator::Build(const knowhere::DatasetPtr& dataset) {
|
||||
index_->BuildWithDataset(dataset);
|
||||
}
|
||||
|
||||
knowhere::BinarySet
|
||||
ScalarIndexCreator::Serialize() {
|
||||
return index_->Serialize(config_);
|
||||
}
|
||||
|
||||
void
|
||||
ScalarIndexCreator::Load(const knowhere::BinarySet& binary_set) {
|
||||
index_->Load(binary_set);
|
||||
}
|
||||
|
||||
std::string
|
||||
ScalarIndexCreator::index_type() {
|
||||
// TODO
|
||||
return "sort";
|
||||
}
|
||||
|
||||
} // namespace milvus::indexbuilder
|
|
@ -12,22 +12,18 @@
|
|||
#pragma once
|
||||
|
||||
#include "indexbuilder/IndexCreatorBase.h"
|
||||
#include "knowhere/index/structured_index_simple/StructuredIndex.h"
|
||||
#include "pb/index_cgo_msg.pb.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <common/CDataType.h>
|
||||
#include "index/Index.h"
|
||||
#include "index/ScalarIndex.h"
|
||||
|
||||
namespace milvus::indexbuilder {
|
||||
|
||||
template <typename T>
|
||||
class ScalarIndexCreator : public IndexCreatorBase {
|
||||
// of course, maybe we can support combination index later.
|
||||
// for example, we can create index for combination of (field a, field b),
|
||||
// attribute filtering on the combination can be speed up.
|
||||
static_assert(std::is_fundamental_v<T> || std::is_same_v<T, std::string>);
|
||||
|
||||
public:
|
||||
ScalarIndexCreator(const char* type_params, const char* index_params);
|
||||
ScalarIndexCreator(CDataType dtype, const char* type_params, const char* index_params);
|
||||
|
||||
void
|
||||
Build(const knowhere::DatasetPtr& dataset) override;
|
||||
|
@ -39,11 +35,22 @@ class ScalarIndexCreator : public IndexCreatorBase {
|
|||
Load(const knowhere::BinarySet&) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<knowhere::scalar::StructuredIndex<T>> index_ = nullptr;
|
||||
std::string
|
||||
index_type();
|
||||
|
||||
private:
|
||||
scalar::IndexBasePtr index_ = nullptr;
|
||||
proto::indexcgo::TypeParams type_params_;
|
||||
proto::indexcgo::IndexParams index_params_;
|
||||
knowhere::Config config_;
|
||||
CDataType dtype_;
|
||||
};
|
||||
} // namespace milvus::indexbuilder
|
||||
|
||||
#include "ScalarIndexCreator-inl.h"
|
||||
using ScalarIndexCreatorPtr = std::unique_ptr<ScalarIndexCreator>;
|
||||
|
||||
inline ScalarIndexCreatorPtr
|
||||
CreateScalarIndex(CDataType dtype, const char* type_params, const char* index_params) {
|
||||
return std::make_unique<ScalarIndexCreator>(dtype, type_params, index_params);
|
||||
}
|
||||
|
||||
} // namespace milvus::indexbuilder
|
||||
|
|
|
@ -1,44 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "StringIndexImpl.h"
|
||||
|
||||
namespace milvus::indexbuilder {
|
||||
|
||||
// TODO: optimize here.
|
||||
|
||||
knowhere::BinarySet
|
||||
StringIndexImpl::Serialize(const knowhere::Config& config) {
|
||||
knowhere::BinarySet res_set;
|
||||
auto data = this->GetData();
|
||||
for (const auto& record : data) {
|
||||
auto idx = record.idx_;
|
||||
auto str = record.a_;
|
||||
std::shared_ptr<uint8_t[]> content(new uint8_t[str.length()]);
|
||||
memcpy(content.get(), str.c_str(), str.length());
|
||||
res_set.Append(std::to_string(idx), content, str.length());
|
||||
}
|
||||
return res_set;
|
||||
}
|
||||
|
||||
void
|
||||
StringIndexImpl::Load(const knowhere::BinarySet& index_binary) {
|
||||
std::vector<std::string> vecs;
|
||||
|
||||
for (const auto& [k, v] : index_binary.binary_map_) {
|
||||
std::string str(reinterpret_cast<const char*>(v->data.get()), v->size);
|
||||
vecs.emplace_back(str);
|
||||
}
|
||||
|
||||
Build(vecs.size(), vecs.data());
|
||||
}
|
||||
|
||||
} // namespace milvus::indexbuilder
|
|
@ -31,4 +31,4 @@ set(MILVUS_QUERY_SRCS
|
|||
PlanProto.cpp
|
||||
)
|
||||
add_library(milvus_query ${MILVUS_QUERY_SRCS})
|
||||
target_link_libraries(milvus_query milvus_proto milvus_utils knowhere boost_bitset_ext)
|
||||
target_link_libraries(milvus_query milvus_index milvus_utils milvus_proto knowhere boost_bitset_ext)
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <index/ScalarIndexSort.h>
|
||||
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/Span.h"
|
||||
|
@ -20,9 +21,9 @@
|
|||
namespace milvus::query {
|
||||
|
||||
template <typename T>
|
||||
inline std::unique_ptr<knowhere::scalar::StructuredIndex<T>>
|
||||
inline scalar::ScalarIndexPtr<T>
|
||||
generate_scalar_index(Span<T> data) {
|
||||
auto indexing = std::make_unique<knowhere::scalar::StructuredIndexSort<T>>();
|
||||
auto indexing = std::make_unique<scalar::ScalarIndexSort<T>>();
|
||||
indexing->Build(data.row_count(), data.data());
|
||||
return indexing;
|
||||
}
|
||||
|
|
|
@ -149,7 +149,7 @@ ExecExprVisitor::ExecRangeVisitorImpl(FieldOffset field_offset, IndexFunc index_
|
|||
auto num_chunk = upper_div(row_count_, size_per_chunk);
|
||||
std::deque<BitsetType> results;
|
||||
|
||||
using Index = knowhere::scalar::StructuredIndex<T>;
|
||||
using Index = scalar::ScalarIndex<T>;
|
||||
for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
|
||||
const Index& indexing = segment_.chunk_scalar_index<T>(field_offset, chunk_id);
|
||||
// NOTE: knowhere is not const-ready
|
||||
|
@ -180,8 +180,8 @@ template <typename T>
|
|||
auto
|
||||
ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType {
|
||||
auto& expr = static_cast<UnaryRangeExprImpl<T>&>(expr_raw);
|
||||
using Index = knowhere::scalar::StructuredIndex<T>;
|
||||
using Operator = knowhere::scalar::OperatorType;
|
||||
using Index = scalar::ScalarIndex<T>;
|
||||
using Operator = scalar::OperatorType;
|
||||
auto op = expr.op_type_;
|
||||
auto val = expr.value_;
|
||||
switch (op) {
|
||||
|
@ -228,7 +228,7 @@ template <typename T>
|
|||
auto
|
||||
ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType {
|
||||
auto& expr = static_cast<BinaryRangeExprImpl<T>&>(expr_raw);
|
||||
using Index = knowhere::scalar::StructuredIndex<T>;
|
||||
using Index = scalar::ScalarIndex<T>;
|
||||
using Operator = knowhere::scalar::OperatorType;
|
||||
bool lower_inclusive = expr.lower_inclusive_;
|
||||
bool upper_inclusive = expr.upper_inclusive_;
|
||||
|
|
|
@ -44,13 +44,12 @@ set(PLATFORM_LIBS )
|
|||
endif ()
|
||||
|
||||
target_link_libraries(milvus_segcore
|
||||
milvus_query
|
||||
milvus_common
|
||||
${PLATFORM_LIBS}
|
||||
pthread
|
||||
${TBB}
|
||||
${OpenMP_CXX_FLAGS}
|
||||
milvus_common
|
||||
knowhere
|
||||
milvus_query
|
||||
# gperftools
|
||||
)
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <index/ScalarIndexSort.h>
|
||||
|
||||
#include "common/SystemProperty.h"
|
||||
#if defined(__linux__) || defined(__MINGW64__)
|
||||
|
@ -116,7 +117,7 @@ ScalarFieldIndexing<T>::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const
|
|||
const auto& chunk = source->get_chunk(chunk_id);
|
||||
// build index for chunk
|
||||
// TODO
|
||||
auto indexing = std::make_unique<knowhere::scalar::StructuredIndexSort<T>>();
|
||||
auto indexing = scalar::CreateScalarIndexSort<T>();
|
||||
indexing->Build(vec_base->get_size_per_chunk(), chunk.data());
|
||||
data_[chunk_id] = std::move(indexing);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
#include <memory>
|
||||
|
||||
#include <tbb/concurrent_vector.h>
|
||||
#include <index/Index.h>
|
||||
#include <index/ScalarIndex.h>
|
||||
|
||||
#include "AckResponder.h"
|
||||
#include "InsertRecord.h"
|
||||
|
@ -70,14 +72,14 @@ class ScalarFieldIndexing : public FieldIndexing {
|
|||
BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) override;
|
||||
|
||||
// concurrent
|
||||
knowhere::scalar::StructuredIndex<T>*
|
||||
scalar::ScalarIndex<T>*
|
||||
get_chunk_indexing(int64_t chunk_id) const override {
|
||||
Assert(!field_meta_.is_vector());
|
||||
return data_.at(chunk_id).get();
|
||||
}
|
||||
|
||||
private:
|
||||
tbb::concurrent_vector<std::unique_ptr<knowhere::scalar::StructuredIndex<T>>> data_;
|
||||
tbb::concurrent_vector<scalar::ScalarIndexPtr<T>> data_;
|
||||
};
|
||||
|
||||
class VectorFieldIndexing : public FieldIndexing {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <index/ScalarIndex.h>
|
||||
|
||||
#include "FieldIndexing.h"
|
||||
#include "common/Schema.h"
|
||||
|
@ -74,10 +75,10 @@ class SegmentInternalInterface : public SegmentInterface {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
const knowhere::scalar::StructuredIndex<T>&
|
||||
const scalar::ScalarIndex<T>&
|
||||
chunk_scalar_index(FieldOffset field_offset, int64_t chunk_id) const {
|
||||
static_assert(IsScalar<T>);
|
||||
using IndexType = knowhere::scalar::StructuredIndex<T>;
|
||||
using IndexType = scalar::ScalarIndex<T>;
|
||||
auto base_ptr = chunk_index_impl(field_offset, chunk_id);
|
||||
auto ptr = dynamic_cast<const IndexType*>(base_ptr);
|
||||
AssertInfo(ptr, "entry mismatch");
|
||||
|
|
|
@ -60,4 +60,10 @@ if ( MILVUS_WITH_OPENTRACING )
|
|||
endif()
|
||||
|
||||
add_subdirectory( protobuf )
|
||||
add_subdirectory( boost_ext )
|
||||
add_subdirectory( boost_ext )
|
||||
|
||||
# ******************************* Thridparty marisa ********************************
|
||||
# TODO: support apple & win.
|
||||
if ( LINUX )
|
||||
add_subdirectory( marisa )
|
||||
endif()
|
|
@ -0,0 +1,69 @@
|
|||
set( MARISA_VERSION "v0.2.6")
|
||||
set( MARISA_MD5 "695cecf504ced27ac13aa33d97d69dd0")
|
||||
|
||||
if ( DEFINED ENV{MILVUS_MARISA_URL} )
|
||||
set( MARISA_SOURCE_URL "$ENV{MILVUS_MARISA_URL}" )
|
||||
else ()
|
||||
set( MARISA_SOURCE_URL
|
||||
"https://github.com/s-yata/marisa-trie/archive/refs/tags/${MARISA_VERSION}.tar.gz" )
|
||||
endif ()
|
||||
|
||||
macro(build_marisa)
|
||||
message(STATUS "Building marisa-${MARISA_VERSION} from source")
|
||||
|
||||
set (MARISA_INSTALL_PREFIX ${CMAKE_INSTALL_PREFIX})
|
||||
set (MARISA_DIR "${CMAKE_CURRENT_BINARY_DIR}/src")
|
||||
set (MARISA_CONFIGURE_COMMAND cd ${MARISA_DIR} && libtoolize && autoreconf -i && ./configure --prefix=${MARISA_INSTALL_PREFIX})
|
||||
set (MARISA_BUILD_COMMAND make -j)
|
||||
set (MARISA_INSTALL_COMMAND make install)
|
||||
|
||||
message(${MARISA_DIR})
|
||||
|
||||
externalproject_add(marisa_ep
|
||||
URL ${MARISA_SOURCE_URL}
|
||||
URL_MD5 ${MARISA_MD5}
|
||||
SOURCE_DIR ${MARISA_DIR}
|
||||
BUILD_IN_SOURCE 1
|
||||
CONFIGURE_COMMAND ${MARISA_CONFIGURE_COMMAND}
|
||||
BUILD_COMMAND ${MARISA_BUILD_COMMAND}
|
||||
INSTALL_COMMAND ${MARISA_INSTALL_COMMAND}
|
||||
)
|
||||
|
||||
if( NOT IS_DIRECTORY ${MARISA_INSTALL_PREFIX}/include )
|
||||
file( MAKE_DIRECTORY "${MARISA_INSTALL_PREFIX}/include" )
|
||||
endif()
|
||||
|
||||
add_library(marisa SHARED IMPORTED)
|
||||
set_target_properties( marisa
|
||||
PROPERTIES
|
||||
IMPORTED_GLOBAL TRUE
|
||||
IMPORTED_LOCATION ${MARISA_INSTALL_PREFIX}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}marisa${CMAKE_SHARED_LIBRARY_SUFFIX}
|
||||
INTERFACE_INCLUDE_DIRECTORIES ${MARISA_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR})
|
||||
|
||||
get_target_property(MARISA_IMPORTED_LOCATION marisa IMPORTED_LOCATION)
|
||||
get_target_property(MARISA_INTERFACE_INCLUDE_DIRECTORIES marisa INTERFACE_INCLUDE_DIRECTORIES)
|
||||
message("MARISA_INSTALL_PREFIX: ${MARISA_INSTALL_PREFIX}")
|
||||
message("CMAKE_INSTALL_LIBDIR: ${CMAKE_INSTALL_LIBDIR}")
|
||||
message("MARISA_IMPORTED_LOCATION: ${MARISA_IMPORTED_LOCATION}")
|
||||
message("MARISA_INTERFACE_INCLUDE_DIRECTORIES: ${MARISA_INTERFACE_INCLUDE_DIRECTORIES}")
|
||||
|
||||
add_dependencies(marisa marisa_ep)
|
||||
endmacro()
|
||||
|
||||
set(MARISA_SOURCE "AUTO")
|
||||
if (MARISA_SOURCE STREQUAL "AUTO")
|
||||
find_package(marisa)
|
||||
message(STATUS "marisa libraries: ${MARISA_LIBRARIES}")
|
||||
message(STATUS "marisa found: ${MARISA_FOUND}")
|
||||
|
||||
if (MARISA_FOUND)
|
||||
add_library(marisa)
|
||||
else()
|
||||
build_marisa()
|
||||
endif()
|
||||
elseif (MARISA_SOURCE STREQUAL "BUNDLED")
|
||||
build_marisa()
|
||||
elseif (MARISA_SOURCE STREQUAL "SYSTEM")
|
||||
find_package(marisa)
|
||||
add_library(marisa)
|
||||
endif ()
|
|
@ -41,7 +41,9 @@ if (LINUX)
|
|||
test_utils.cpp
|
||||
test_scalar_index_creator.cpp
|
||||
test_index_c_api.cpp
|
||||
test_index.cpp
|
||||
test_scalar_index.cpp
|
||||
test_string_index.cpp
|
||||
test_bool_index.cpp
|
||||
)
|
||||
# check if memory leak exists in index builder
|
||||
set(INDEX_BUILDER_TEST_FILES
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <pb/schema.pb.h>
|
||||
#include <index/BoolIndex.h>
|
||||
#include "test_utils/indexbuilder_test_utils.h"
|
||||
|
||||
class BoolIndexTest : public ::testing::Test {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
n = 8;
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
*(all_true.mutable_data()->Add()) = true;
|
||||
*(all_false.mutable_data()->Add()) = false;
|
||||
*(half.mutable_data()->Add()) = (i % 2) == 0;
|
||||
}
|
||||
|
||||
all_true_ds = GenDsFromPB(all_true);
|
||||
all_false_ds = GenDsFromPB(all_false);
|
||||
half_ds = GenDsFromPB(half);
|
||||
}
|
||||
|
||||
void
|
||||
TearDown() override {
|
||||
delete[](char*)(all_true_ds->Get<const void*>(knowhere::meta::TENSOR));
|
||||
delete[](char*) all_false_ds->Get<const void*>(knowhere::meta::TENSOR);
|
||||
delete[](char*) half_ds->Get<const void*>(knowhere::meta::TENSOR);
|
||||
}
|
||||
|
||||
protected:
|
||||
schemapb::BoolArray all_true;
|
||||
schemapb::BoolArray all_false;
|
||||
schemapb::BoolArray half;
|
||||
knowhere::DatasetPtr all_true_ds;
|
||||
knowhere::DatasetPtr all_false_ds;
|
||||
knowhere::DatasetPtr half_ds;
|
||||
size_t n;
|
||||
std::vector<ScalarTestParams> params;
|
||||
};
|
||||
|
||||
TEST_F(BoolIndexTest, Constructor) {
|
||||
auto index = milvus::scalar::CreateBoolIndex();
|
||||
}
|
||||
|
||||
TEST_F(BoolIndexTest, In) {
|
||||
auto true_test = std::make_unique<bool>(true);
|
||||
auto false_test = std::make_unique<bool>(false);
|
||||
|
||||
{
|
||||
auto index = milvus::scalar::CreateBoolIndex();
|
||||
index->BuildWithDataset(all_true_ds);
|
||||
|
||||
auto bitset1 = index->In(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
|
||||
auto bitset2 = index->In(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->none());
|
||||
}
|
||||
|
||||
{
|
||||
auto index = milvus::scalar::CreateBoolIndex();
|
||||
index->BuildWithDataset(all_false_ds);
|
||||
|
||||
auto bitset1 = index->In(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
|
||||
auto bitset2 = index->In(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
}
|
||||
|
||||
{
|
||||
auto index = milvus::scalar::CreateBoolIndex();
|
||||
index->BuildWithDataset(half_ds);
|
||||
|
||||
auto bitset1 = index->In(1, true_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset1->test(i), (i % 2) == 0);
|
||||
}
|
||||
|
||||
auto bitset2 = index->In(1, false_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset2->test(i), (i % 2) != 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(BoolIndexTest, NotIn) {
|
||||
auto true_test = std::make_unique<bool>(true);
|
||||
auto false_test = std::make_unique<bool>(false);
|
||||
|
||||
{
|
||||
auto index = milvus::scalar::CreateBoolIndex();
|
||||
index->BuildWithDataset(all_true_ds);
|
||||
|
||||
auto bitset1 = index->NotIn(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
|
||||
auto bitset2 = index->NotIn(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
}
|
||||
|
||||
{
|
||||
auto index = milvus::scalar::CreateBoolIndex();
|
||||
index->BuildWithDataset(all_false_ds);
|
||||
|
||||
auto bitset1 = index->NotIn(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
|
||||
auto bitset2 = index->NotIn(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->none());
|
||||
}
|
||||
|
||||
{
|
||||
auto index = milvus::scalar::CreateBoolIndex();
|
||||
index->BuildWithDataset(half_ds);
|
||||
|
||||
auto bitset1 = index->NotIn(1, true_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset1->test(i), (i % 2) != 0);
|
||||
}
|
||||
|
||||
auto bitset2 = index->NotIn(1, false_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset2->test(i), (i % 2) == 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(BoolIndexTest, Codec) {
|
||||
auto true_test = std::make_unique<bool>(true);
|
||||
auto false_test = std::make_unique<bool>(false);
|
||||
|
||||
{
|
||||
auto index = milvus::scalar::CreateBoolIndex();
|
||||
index->BuildWithDataset(all_true_ds);
|
||||
|
||||
auto copy_index = milvus::scalar::CreateBoolIndex();
|
||||
copy_index->Load(index->Serialize(nullptr));
|
||||
|
||||
auto bitset1 = copy_index->NotIn(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
|
||||
auto bitset2 = copy_index->NotIn(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
}
|
||||
|
||||
{
|
||||
auto index = milvus::scalar::CreateBoolIndex();
|
||||
index->BuildWithDataset(all_false_ds);
|
||||
|
||||
auto copy_index = milvus::scalar::CreateBoolIndex();
|
||||
copy_index->Load(index->Serialize(nullptr));
|
||||
|
||||
auto bitset1 = copy_index->NotIn(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
|
||||
auto bitset2 = copy_index->NotIn(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->none());
|
||||
}
|
||||
|
||||
{
|
||||
auto index = milvus::scalar::CreateBoolIndex();
|
||||
index->BuildWithDataset(half_ds);
|
||||
|
||||
auto copy_index = milvus::scalar::CreateBoolIndex();
|
||||
copy_index->Load(index->Serialize(nullptr));
|
||||
|
||||
auto bitset1 = copy_index->NotIn(1, true_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset1->test(i), (i % 2) != 0);
|
||||
}
|
||||
|
||||
auto bitset2 = copy_index->NotIn(1, false_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset2->test(i), (i % 2) == 0);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -9,134 +9,20 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <map>
|
||||
#include <tuple>
|
||||
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||
#include <knowhere/index/vector_index/ConfAdapterMgr.h>
|
||||
#include <knowhere/archive/KnowhereConfig.h>
|
||||
#include "pb/index_cgo_msg.pb.h"
|
||||
|
||||
#define private public
|
||||
|
||||
#include "index/IndexFactory.h"
|
||||
#include "index/Index.h"
|
||||
#include "index/ScalarIndex.h"
|
||||
#include "index/ScalarIndexSort.h"
|
||||
#include "common/CDataType.h"
|
||||
#include "test_utils/indexbuilder_test_utils.h"
|
||||
#include "test_utils/AssertUtils.h"
|
||||
|
||||
constexpr int64_t nb = 100;
|
||||
namespace indexcgo = milvus::proto::indexcgo;
|
||||
namespace schemapb = milvus::proto::schema;
|
||||
using milvus::scalar::ScalarIndexPtr;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
inline std::vector<std::string>
|
||||
GetIndexTypes() {
|
||||
return std::vector<std::string>{"inverted_index"};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<std::string>
|
||||
GetIndexTypes<std::string>() {
|
||||
return std::vector<std::string>{"marisa-trie"};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void
|
||||
assert_in(const ScalarIndexPtr<T>& index, const std::vector<T>& arr) {
|
||||
// hard to compare floating point value.
|
||||
if (std::is_floating_point_v<T>) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto bitset1 = index->In(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
auto test = std::make_unique<T>(arr[arr.size() - 1] + 1);
|
||||
auto bitset2 = index->In(1, test.get());
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->none());
|
||||
}
|
||||
template <typename T>
|
||||
inline void
|
||||
assert_not_in(const ScalarIndexPtr<T>& index, const std::vector<T>& arr) {
|
||||
auto bitset1 = index->NotIn(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
auto test = std::make_unique<T>(arr[arr.size() - 1] + 1);
|
||||
auto bitset2 = index->NotIn(1, test.get());
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void
|
||||
assert_range(const ScalarIndexPtr<T>& index, const std::vector<T>& arr) {
|
||||
auto test_min = arr[0];
|
||||
auto test_max = arr[arr.size() - 1];
|
||||
|
||||
auto bitset1 = index->Range(test_min - 1, OperatorType::GT);
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
|
||||
auto bitset2 = index->Range(test_min, OperatorType::GE);
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
|
||||
auto bitset3 = index->Range(test_max + 1, OperatorType::LT);
|
||||
ASSERT_EQ(arr.size(), bitset3->size());
|
||||
ASSERT_TRUE(bitset3->any());
|
||||
|
||||
auto bitset4 = index->Range(test_max, OperatorType::LE);
|
||||
ASSERT_EQ(arr.size(), bitset4->size());
|
||||
ASSERT_TRUE(bitset4->any());
|
||||
|
||||
auto bitset5 = index->Range(test_min, true, test_max, true);
|
||||
ASSERT_EQ(arr.size(), bitset5->size());
|
||||
ASSERT_TRUE(bitset5->any());
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
assert_in(const ScalarIndexPtr<std::string>& index, const std::vector<std::string>& arr) {
|
||||
auto bitset1 = index->In(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
assert_not_in(const ScalarIndexPtr<std::string>& index, const std::vector<std::string>& arr) {
|
||||
auto bitset1 = index->NotIn(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
assert_range(const ScalarIndexPtr<std::string>& index, const std::vector<std::string>& arr) {
|
||||
auto test_min = arr[0];
|
||||
auto test_max = arr[arr.size() - 1];
|
||||
|
||||
auto bitset2 = index->Range(test_min, OperatorType::GE);
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
|
||||
auto bitset4 = index->Range(test_max, OperatorType::LE);
|
||||
ASSERT_EQ(arr.size(), bitset4->size());
|
||||
ASSERT_TRUE(bitset4->any());
|
||||
|
||||
auto bitset5 = index->Range(test_min, true, test_max, true);
|
||||
ASSERT_EQ(arr.size(), bitset5->size());
|
||||
ASSERT_TRUE(bitset5->any());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
class TypedScalarIndexTest : public ::testing::Test {
|
||||
protected:
|
||||
|
@ -222,10 +108,8 @@ TYPED_TEST_P(TypedScalarIndexTest, Codec) {
|
|||
}
|
||||
|
||||
// TODO: it's easy to overflow for int8_t. Design more reasonable ut.
|
||||
using ArithmeticT = ::testing::Types<int8_t, int16_t, int32_t, int64_t, float, double, std::string>;
|
||||
using ScalarT = ::testing::Types<int8_t, int16_t, int32_t, int64_t, float, double>;
|
||||
|
||||
REGISTER_TYPED_TEST_CASE_P(TypedScalarIndexTest, Dummy, Constructor, In, NotIn, Range, Codec);
|
||||
|
||||
INSTANTIATE_TYPED_TEST_CASE_P(ArithmeticCheck, TypedScalarIndexTest, ArithmeticT);
|
||||
|
||||
// TODO: bool.
|
||||
INSTANTIATE_TYPED_TEST_CASE_P(ArithmeticCheck, TypedScalarIndexTest, ScalarT);
|
|
@ -9,21 +9,15 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <map>
|
||||
#include <tuple>
|
||||
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||
#include <knowhere/index/vector_index/ConfAdapterMgr.h>
|
||||
#include <knowhere/archive/KnowhereConfig.h>
|
||||
#include "pb/index_cgo_msg.pb.h"
|
||||
|
||||
#define private public
|
||||
|
||||
#include "indexbuilder/VecIndexCreator.h"
|
||||
#include "indexbuilder/index_c.h"
|
||||
#include "indexbuilder/utils.h"
|
||||
#include "test_utils/DataGen.h"
|
||||
#include "test_utils/indexbuilder_test_utils.h"
|
||||
#include "indexbuilder/ScalarIndexCreator.h"
|
||||
|
@ -38,100 +32,44 @@ namespace indexcgo = milvus::proto::indexcgo;
|
|||
namespace schemapb = milvus::proto::schema;
|
||||
using knowhere::scalar::OperatorType;
|
||||
using milvus::indexbuilder::MapParams;
|
||||
using milvus::indexbuilder::ScalarIndexCreator;
|
||||
using milvus::indexbuilder::ScalarIndexCreatorPtr;
|
||||
using ScalarTestParams = std::pair<MapParams, MapParams>;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T> | std::is_same_v<T, std::string>>>
|
||||
inline void
|
||||
assert_in(const std::unique_ptr<ScalarIndexCreator<T>>& creator, const std::vector<T>& arr) {
|
||||
// hard to compare floating point value.
|
||||
if (std::is_floating_point_v<T>) {
|
||||
return;
|
||||
build_index(const ScalarIndexCreatorPtr& creator, const std::vector<T>& arr) {
|
||||
const int64_t dim = 8; // not important here
|
||||
auto dataset = knowhere::GenDataset(arr.size(), dim, arr.data());
|
||||
creator->Build(dataset);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
build_index(const ScalarIndexCreatorPtr& creator, const std::vector<bool>& arr) {
|
||||
schemapb::BoolArray pbarr;
|
||||
for (auto b : arr) {
|
||||
pbarr.add_data(b);
|
||||
}
|
||||
auto ds = GenDsFromPB(pbarr);
|
||||
|
||||
auto bitset1 = creator->index_->In(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
auto test = std::make_unique<T>(arr[arr.size() - 1] + 1);
|
||||
auto bitset2 = creator->index_->In(1, test.get());
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->none());
|
||||
}
|
||||
creator->Build(ds);
|
||||
|
||||
template <typename T>
|
||||
inline void
|
||||
assert_not_in(const std::unique_ptr<ScalarIndexCreator<T>>& creator, const std::vector<T>& arr) {
|
||||
auto bitset1 = creator->index_->NotIn(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
auto test = std::make_unique<T>(arr[arr.size() - 1] + 1);
|
||||
auto bitset2 = creator->index_->NotIn(1, test.get());
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void
|
||||
assert_range(const std::unique_ptr<ScalarIndexCreator<T>>& creator, const std::vector<T>& arr) {
|
||||
auto test_min = arr[0];
|
||||
auto test_max = arr[arr.size() - 1];
|
||||
|
||||
auto bitset1 = creator->index_->Range(test_min - 1, OperatorType::GT);
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
|
||||
auto bitset2 = creator->index_->Range(test_min, OperatorType::GE);
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
|
||||
auto bitset3 = creator->index_->Range(test_max + 1, OperatorType::LT);
|
||||
ASSERT_EQ(arr.size(), bitset3->size());
|
||||
ASSERT_TRUE(bitset3->any());
|
||||
|
||||
auto bitset4 = creator->index_->Range(test_max, OperatorType::LE);
|
||||
ASSERT_EQ(arr.size(), bitset4->size());
|
||||
ASSERT_TRUE(bitset4->any());
|
||||
|
||||
auto bitset5 = creator->index_->Range(test_min, true, test_max, true);
|
||||
ASSERT_EQ(arr.size(), bitset5->size());
|
||||
ASSERT_TRUE(bitset5->any());
|
||||
delete[](char*) ds->Get<const void*>(knowhere::meta::TENSOR);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
assert_in(const std::unique_ptr<ScalarIndexCreator<std::string>>& creator, const std::vector<std::string>& arr) {
|
||||
auto bitset1 = creator->index_->In(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
build_index(const ScalarIndexCreatorPtr& creator, const std::vector<std::string>& arr) {
|
||||
schemapb::StringArray pbarr;
|
||||
*(pbarr.mutable_data()) = {arr.begin(), arr.end()};
|
||||
auto ds = GenDsFromPB(pbarr);
|
||||
|
||||
creator->Build(ds);
|
||||
|
||||
delete[](char*) ds->Get<const void*>(knowhere::meta::TENSOR);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
assert_not_in(const std::unique_ptr<ScalarIndexCreator<std::string>>& creator, const std::vector<std::string>& arr) {
|
||||
auto bitset1 = creator->index_->NotIn(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
assert_range(const std::unique_ptr<ScalarIndexCreator<std::string>>& creator, const std::vector<std::string>& arr) {
|
||||
auto test_min = arr[0];
|
||||
auto test_max = arr[arr.size() - 1];
|
||||
|
||||
auto bitset2 = creator->index_->Range(test_min, OperatorType::GE);
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
|
||||
auto bitset4 = creator->index_->Range(test_max, OperatorType::LE);
|
||||
ASSERT_EQ(arr.size(), bitset4->size());
|
||||
ASSERT_TRUE(bitset4->any());
|
||||
|
||||
auto bitset5 = creator->index_->Range(test_min, true, test_max, true);
|
||||
ASSERT_EQ(arr.size(), bitset5->size());
|
||||
ASSERT_TRUE(bitset5->any());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
|
@ -146,8 +84,7 @@ class TypedScalarIndexCreatorTest : public ::testing::Test {
|
|||
// }
|
||||
};
|
||||
|
||||
// TODO: it's easy to overflow for int8_t. Design more reasonable ut.
|
||||
using ArithmeticT = ::testing::Types<int8_t, int16_t, int32_t, int64_t, float, double>;
|
||||
using ScalarT = ::testing::Types<bool, int8_t, int16_t, int32_t, int64_t, float, double, std::string>;
|
||||
|
||||
TYPED_TEST_CASE_P(TypedScalarIndexCreatorTest);
|
||||
|
||||
|
@ -159,430 +96,36 @@ TYPED_TEST_P(TypedScalarIndexCreatorTest, Dummy) {
|
|||
|
||||
TYPED_TEST_P(TypedScalarIndexCreatorTest, Constructor) {
|
||||
using T = TypeParam;
|
||||
auto dtype = milvus::GetDType<T>();
|
||||
for (const auto& tp : GenParams<T>()) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST_P(TypedScalarIndexCreatorTest, In) {
|
||||
using T = TypeParam;
|
||||
for (const auto& tp : GenParams<T>()) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
auto arr = GenArr<T>(nb);
|
||||
build_index<T>(creator, arr);
|
||||
assert_in<T>(creator, arr);
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST_P(TypedScalarIndexCreatorTest, NotIn) {
|
||||
using T = TypeParam;
|
||||
for (const auto& tp : GenParams<T>()) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
auto arr = GenArr<T>(nb);
|
||||
build_index<T>(creator, arr);
|
||||
assert_not_in<T>(creator, arr);
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST_P(TypedScalarIndexCreatorTest, Range) {
|
||||
using T = TypeParam;
|
||||
for (const auto& tp : GenParams<T>()) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
auto arr = GenArr<T>(nb);
|
||||
build_index<T>(creator, arr);
|
||||
assert_range<T>(creator, arr);
|
||||
auto creator = milvus::indexbuilder::CreateScalarIndex(dtype, serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST_P(TypedScalarIndexCreatorTest, Codec) {
|
||||
using T = TypeParam;
|
||||
auto dtype = milvus::GetDType<T>();
|
||||
for (const auto& tp : GenParams<T>()) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
auto creator = milvus::indexbuilder::CreateScalarIndex(dtype, serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
auto arr = GenArr<T>(nb);
|
||||
const int64_t dim = 8; // not important here
|
||||
auto dataset = knowhere::GenDataset(arr.size(), dim, arr.data());
|
||||
creator->Build(dataset);
|
||||
|
||||
build_index<T>(creator, arr);
|
||||
auto binary_set = creator->Serialize();
|
||||
auto copy_creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
auto copy_creator = milvus::indexbuilder::CreateScalarIndex(dtype, serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
copy_creator->Load(binary_set);
|
||||
assert_in<T>(copy_creator, arr);
|
||||
assert_not_in<T>(copy_creator, arr);
|
||||
assert_range<T>(copy_creator, arr);
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_CASE_P(TypedScalarIndexCreatorTest, Dummy, Constructor, In, NotIn, Range, Codec);
|
||||
REGISTER_TYPED_TEST_CASE_P(TypedScalarIndexCreatorTest, Dummy, Constructor, Codec);
|
||||
|
||||
INSTANTIATE_TYPED_TEST_CASE_P(ArithmeticCheck, TypedScalarIndexCreatorTest, ArithmeticT);
|
||||
|
||||
class BoolIndexTest : public ::testing::Test {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
n = 8;
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
*(all_true.mutable_data()->Add()) = true;
|
||||
*(all_false.mutable_data()->Add()) = false;
|
||||
*(half.mutable_data()->Add()) = (i % 2) == 0;
|
||||
}
|
||||
|
||||
all_true_ds = GenDsFromPB(all_true);
|
||||
all_false_ds = GenDsFromPB(all_false);
|
||||
half_ds = GenDsFromPB(half);
|
||||
|
||||
GenTestParams();
|
||||
}
|
||||
|
||||
void
|
||||
TearDown() override {
|
||||
delete[](char*)(all_true_ds->Get<const void*>(knowhere::meta::TENSOR));
|
||||
delete[](char*) all_false_ds->Get<const void*>(knowhere::meta::TENSOR);
|
||||
delete[](char*) half_ds->Get<const void*>(knowhere::meta::TENSOR);
|
||||
}
|
||||
|
||||
private:
|
||||
void
|
||||
GenTestParams() {
|
||||
params = GenBoolParams();
|
||||
}
|
||||
|
||||
protected:
|
||||
schemapb::BoolArray all_true;
|
||||
schemapb::BoolArray all_false;
|
||||
schemapb::BoolArray half;
|
||||
knowhere::DatasetPtr all_true_ds;
|
||||
knowhere::DatasetPtr all_false_ds;
|
||||
knowhere::DatasetPtr half_ds;
|
||||
size_t n;
|
||||
std::vector<ScalarTestParams> params;
|
||||
};
|
||||
|
||||
TEST_F(BoolIndexTest, Constructor) {
|
||||
using T = bool;
|
||||
for (const auto& tp : params) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(BoolIndexTest, In) {
|
||||
using T = bool;
|
||||
for (const auto& tp : params) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
|
||||
auto true_test = std::make_unique<bool>(true);
|
||||
auto false_test = std::make_unique<bool>(false);
|
||||
|
||||
{
|
||||
auto creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
|
||||
creator->Build(all_true_ds);
|
||||
|
||||
auto bitset1 = creator->index_->In(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
|
||||
auto bitset2 = creator->index_->In(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->none());
|
||||
}
|
||||
|
||||
{
|
||||
auto creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
|
||||
creator->Build(all_false_ds);
|
||||
|
||||
auto bitset1 = creator->index_->In(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
|
||||
auto bitset2 = creator->index_->In(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
}
|
||||
|
||||
{
|
||||
auto creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
|
||||
creator->Build(half_ds);
|
||||
|
||||
auto bitset1 = creator->index_->In(1, true_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset1->test(i), (i % 2) == 0);
|
||||
}
|
||||
|
||||
auto bitset2 = creator->index_->In(1, false_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset2->test(i), (i % 2) != 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(BoolIndexTest, NotIn) {
|
||||
using T = bool;
|
||||
for (const auto& tp : params) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
|
||||
auto true_test = std::make_unique<bool>(true);
|
||||
auto false_test = std::make_unique<bool>(false);
|
||||
|
||||
{
|
||||
auto creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
|
||||
creator->Build(all_true_ds);
|
||||
|
||||
auto bitset1 = creator->index_->NotIn(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
|
||||
auto bitset2 = creator->index_->NotIn(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
}
|
||||
|
||||
{
|
||||
auto creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
|
||||
creator->Build(all_false_ds);
|
||||
|
||||
auto bitset1 = creator->index_->NotIn(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
|
||||
auto bitset2 = creator->index_->NotIn(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->none());
|
||||
}
|
||||
|
||||
{
|
||||
auto creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
|
||||
creator->Build(half_ds);
|
||||
|
||||
auto bitset1 = creator->index_->NotIn(1, true_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset1->test(i), (i % 2) != 0);
|
||||
}
|
||||
|
||||
auto bitset2 = creator->index_->NotIn(1, false_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset2->test(i), (i % 2) == 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(BoolIndexTest, Codec) {
|
||||
using T = bool;
|
||||
for (const auto& tp : params) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
|
||||
auto true_test = std::make_unique<bool>(true);
|
||||
auto false_test = std::make_unique<bool>(false);
|
||||
|
||||
{
|
||||
auto creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
|
||||
creator->Build(all_true_ds);
|
||||
|
||||
auto copy_creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
copy_creator->Load(creator->Serialize());
|
||||
|
||||
auto bitset1 = copy_creator->index_->NotIn(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
|
||||
auto bitset2 = copy_creator->index_->NotIn(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
}
|
||||
|
||||
{
|
||||
auto creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
|
||||
creator->Build(all_false_ds);
|
||||
|
||||
auto copy_creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
copy_creator->Load(creator->Serialize());
|
||||
|
||||
auto bitset1 = copy_creator->index_->NotIn(1, true_test.get());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
|
||||
auto bitset2 = copy_creator->index_->NotIn(1, false_test.get());
|
||||
ASSERT_TRUE(bitset2->none());
|
||||
}
|
||||
|
||||
{
|
||||
auto creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
|
||||
creator->Build(half_ds);
|
||||
|
||||
auto copy_creator = std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(),
|
||||
serialized_index_params.c_str());
|
||||
copy_creator->Load(creator->Serialize());
|
||||
|
||||
auto bitset1 = copy_creator->index_->NotIn(1, true_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset1->test(i), (i % 2) != 0);
|
||||
}
|
||||
|
||||
auto bitset2 = copy_creator->index_->NotIn(1, false_test.get());
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ASSERT_EQ(bitset2->test(i), (i % 2) == 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class StringIndexTest : public ::testing::Test {
|
||||
void
|
||||
SetUp() override {
|
||||
size_t n = 10;
|
||||
strs = GenStrArr(n);
|
||||
*str_arr.mutable_data() = {strs.begin(), strs.end()};
|
||||
str_ds = GenDsFromPB(str_arr);
|
||||
|
||||
GenTestParams();
|
||||
}
|
||||
|
||||
void
|
||||
TearDown() override {
|
||||
delete[](char*)(str_ds->Get<const void*>(knowhere::meta::TENSOR));
|
||||
}
|
||||
|
||||
private:
|
||||
void
|
||||
GenTestParams() {
|
||||
params = GenStringParams();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::string> strs;
|
||||
schemapb::StringArray str_arr;
|
||||
knowhere::DatasetPtr str_ds;
|
||||
std::vector<ScalarTestParams> params;
|
||||
};
|
||||
|
||||
TEST_F(StringIndexTest, Constructor) {
|
||||
using T = std::string;
|
||||
for (const auto& tp : params) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(StringIndexTest, In) {
|
||||
using T = std::string;
|
||||
for (const auto& tp : params) {
|
||||
PrintMapParam(tp);
|
||||
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
creator->Build(str_ds);
|
||||
assert_in<T>(creator, strs);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(StringIndexTest, NotIn) {
|
||||
using T = std::string;
|
||||
for (const auto& tp : params) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
creator->Build(str_ds);
|
||||
assert_not_in<T>(creator, strs);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(StringIndexTest, Range) {
|
||||
using T = std::string;
|
||||
for (const auto& tp : params) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
creator->Build(str_ds);
|
||||
assert_range<T>(creator, strs);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(StringIndexTest, Codec) {
|
||||
using T = std::string;
|
||||
for (const auto& tp : params) {
|
||||
auto type_params = tp.first;
|
||||
auto index_params = tp.second;
|
||||
auto serialized_type_params = generate_type_params(type_params);
|
||||
auto serialized_index_params = generate_index_params(index_params);
|
||||
|
||||
auto creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
creator->Build(str_ds);
|
||||
|
||||
auto copy_creator =
|
||||
std::make_unique<ScalarIndexCreator<T>>(serialized_type_params.c_str(), serialized_index_params.c_str());
|
||||
auto binary_set = creator->Serialize();
|
||||
copy_creator->Load(binary_set);
|
||||
assert_in<std::string>(copy_creator, strs);
|
||||
assert_not_in<std::string>(copy_creator, strs);
|
||||
assert_range<std::string>(copy_creator, strs);
|
||||
}
|
||||
}
|
||||
INSTANTIATE_TYPED_TEST_CASE_P(ArithmeticCheck, TypedScalarIndexCreatorTest, ScalarT);
|
||||
|
|
|
@ -0,0 +1,184 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||
#include <knowhere/archive/KnowhereConfig.h>
|
||||
|
||||
#define private public
|
||||
|
||||
#include "index/Index.h"
|
||||
#include "index/ScalarIndex.h"
|
||||
#include "index/StringIndex.h"
|
||||
#include "index/StringIndexMarisa.h"
|
||||
#include "test_utils/indexbuilder_test_utils.h"
|
||||
|
||||
constexpr int64_t nb = 100;
|
||||
namespace schemapb = milvus::proto::schema;
|
||||
|
||||
class StringIndexBaseTest : public ::testing::Test {
|
||||
void
|
||||
SetUp() override {
|
||||
size_t n = 10;
|
||||
strs = GenStrArr(n);
|
||||
*str_arr.mutable_data() = {strs.begin(), strs.end()};
|
||||
str_ds = GenDsFromPB(str_arr);
|
||||
}
|
||||
|
||||
void
|
||||
TearDown() override {
|
||||
delete[](char*)(str_ds->Get<const void*>(knowhere::meta::TENSOR));
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::string> strs;
|
||||
schemapb::StringArray str_arr;
|
||||
knowhere::DatasetPtr str_ds;
|
||||
};
|
||||
|
||||
class StringIndexMarisaTest : public StringIndexBaseTest {};
|
||||
|
||||
TEST_F(StringIndexMarisaTest, Constructor) {
|
||||
auto index = milvus::scalar::CreateStringIndexMarisa();
|
||||
}
|
||||
|
||||
TEST_F(StringIndexMarisaTest, Build) {
|
||||
auto index = milvus::scalar::CreateStringIndexMarisa();
|
||||
index->Build(strs.size(), strs.data());
|
||||
}
|
||||
|
||||
TEST_F(StringIndexMarisaTest, BuildWithDataset) {
|
||||
auto index = milvus::scalar::CreateStringIndexMarisa();
|
||||
index->BuildWithDataset(str_ds);
|
||||
}
|
||||
|
||||
TEST_F(StringIndexMarisaTest, In) {
|
||||
auto index = milvus::scalar::CreateStringIndexMarisa();
|
||||
index->BuildWithDataset(str_ds);
|
||||
auto bitset = index->In(strs.size(), strs.data());
|
||||
ASSERT_EQ(bitset->size(), strs.size());
|
||||
ASSERT_TRUE(bitset->any());
|
||||
}
|
||||
|
||||
TEST_F(StringIndexMarisaTest, NotIn) {
|
||||
auto index = milvus::scalar::CreateStringIndexMarisa();
|
||||
index->BuildWithDataset(str_ds);
|
||||
auto bitset = index->NotIn(strs.size(), strs.data());
|
||||
ASSERT_EQ(bitset->size(), strs.size());
|
||||
ASSERT_TRUE(bitset->none());
|
||||
}
|
||||
|
||||
TEST_F(StringIndexMarisaTest, Range) {
|
||||
auto index = milvus::scalar::CreateStringIndexMarisa();
|
||||
index->BuildWithDataset(str_ds);
|
||||
|
||||
ASSERT_ANY_THROW(index->Range("not important", milvus::scalar::OperatorType::LE));
|
||||
ASSERT_ANY_THROW(index->Range("not important", true, "not important", true));
|
||||
}
|
||||
|
||||
TEST_F(StringIndexMarisaTest, PrefixMatch) {
|
||||
auto index = milvus::scalar::CreateStringIndexMarisa();
|
||||
index->BuildWithDataset(str_ds);
|
||||
|
||||
for (size_t i = 0; i < strs.size(); i++) {
|
||||
auto str = strs[i];
|
||||
auto bitset = index->PrefixMatch(str);
|
||||
ASSERT_EQ(bitset->size(), strs.size());
|
||||
ASSERT_TRUE(bitset->test(i));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(StringIndexMarisaTest, Query) {
|
||||
auto index = milvus::scalar::CreateStringIndexMarisa();
|
||||
index->BuildWithDataset(str_ds);
|
||||
|
||||
{
|
||||
auto ds = knowhere::GenDataset(strs.size(), 8, strs.data());
|
||||
ds->Set<milvus::scalar::OperatorType>(milvus::scalar::OPERATOR_TYPE, milvus::scalar::OperatorType::InOp);
|
||||
auto bitset = index->Query(ds);
|
||||
ASSERT_TRUE(bitset->any());
|
||||
}
|
||||
|
||||
{
|
||||
auto ds = knowhere::GenDataset(strs.size(), 8, strs.data());
|
||||
ds->Set<milvus::scalar::OperatorType>(milvus::scalar::OPERATOR_TYPE, milvus::scalar::OperatorType::NotInOp);
|
||||
auto bitset = index->Query(ds);
|
||||
ASSERT_TRUE(bitset->none());
|
||||
}
|
||||
|
||||
{
|
||||
auto ds = std::make_shared<knowhere::Dataset>();
|
||||
ds->Set<milvus::scalar::OperatorType>(milvus::scalar::OPERATOR_TYPE, milvus::scalar::OperatorType::GE);
|
||||
ds->Set<std::string>(milvus::scalar::RANGE_VALUE, "range");
|
||||
ASSERT_ANY_THROW(index->Query(ds));
|
||||
}
|
||||
|
||||
{
|
||||
auto ds = std::make_shared<knowhere::Dataset>();
|
||||
ds->Set<milvus::scalar::OperatorType>(milvus::scalar::OPERATOR_TYPE, milvus::scalar::OperatorType::RangeOp);
|
||||
ds->Set<std::string>(milvus::scalar::LOWER_BOUND_VALUE, "range");
|
||||
ds->Set<std::string>(milvus::scalar::UPPER_BOUND_VALUE, "range");
|
||||
ds->Set<bool>(milvus::scalar::LOWER_BOUND_INCLUSIVE, true);
|
||||
ds->Set<bool>(milvus::scalar::UPPER_BOUND_INCLUSIVE, true);
|
||||
ASSERT_ANY_THROW(index->Query(ds));
|
||||
}
|
||||
|
||||
{
|
||||
for (size_t i = 0; i < strs.size(); i++) {
|
||||
auto ds = std::make_shared<knowhere::Dataset>();
|
||||
ds->Set<milvus::scalar::OperatorType>(milvus::scalar::OPERATOR_TYPE,
|
||||
milvus::scalar::OperatorType::PrefixMatchOp);
|
||||
ds->Set<std::string>(milvus::scalar::PREFIX_VALUE, std::move(strs[i]));
|
||||
auto bitset = index->Query(ds);
|
||||
ASSERT_EQ(bitset->size(), strs.size());
|
||||
ASSERT_TRUE(bitset->test(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(StringIndexMarisaTest, Codec) {
|
||||
auto index = milvus::scalar::CreateStringIndexMarisa();
|
||||
index->BuildWithDataset(str_ds);
|
||||
|
||||
auto copy_index = milvus::scalar::CreateStringIndexMarisa();
|
||||
|
||||
{
|
||||
auto binary_set = index->Serialize(nullptr);
|
||||
copy_index->Load(binary_set);
|
||||
}
|
||||
|
||||
{
|
||||
auto bitset = copy_index->In(strs.size(), strs.data());
|
||||
ASSERT_EQ(bitset->size(), strs.size());
|
||||
ASSERT_TRUE(bitset->any());
|
||||
}
|
||||
|
||||
{
|
||||
auto bitset = copy_index->NotIn(strs.size(), strs.data());
|
||||
ASSERT_EQ(bitset->size(), strs.size());
|
||||
ASSERT_TRUE(bitset->none());
|
||||
}
|
||||
|
||||
{
|
||||
ASSERT_ANY_THROW(copy_index->Range("not important", milvus::scalar::OperatorType::LE));
|
||||
ASSERT_ANY_THROW(copy_index->Range("not important", true, "not important", true));
|
||||
}
|
||||
|
||||
{
|
||||
for (size_t i = 0; i < strs.size(); i++) {
|
||||
auto str = strs[i];
|
||||
auto bitset = copy_index->PrefixMatch(str);
|
||||
ASSERT_EQ(bitset->size(), strs.size());
|
||||
ASSERT_TRUE(bitset->test(i));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
using milvus::scalar::ScalarIndexPtr;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
inline void
|
||||
assert_in(const ScalarIndexPtr<T>& index, const std::vector<T>& arr) {
|
||||
// hard to compare floating point value.
|
||||
if (std::is_floating_point_v<T>) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto bitset1 = index->In(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
auto test = std::make_unique<T>(arr[arr.size() - 1] + 1);
|
||||
auto bitset2 = index->In(1, test.get());
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->none());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void
|
||||
assert_not_in(const ScalarIndexPtr<T>& index, const std::vector<T>& arr) {
|
||||
auto bitset1 = index->NotIn(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
auto test = std::make_unique<T>(arr[arr.size() - 1] + 1);
|
||||
auto bitset2 = index->NotIn(1, test.get());
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void
|
||||
assert_range(const ScalarIndexPtr<T>& index, const std::vector<T>& arr) {
|
||||
auto test_min = arr[0];
|
||||
auto test_max = arr[arr.size() - 1];
|
||||
|
||||
auto bitset1 = index->Range(test_min - 1, milvus::scalar::OperatorType::GT);
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
|
||||
auto bitset2 = index->Range(test_min, milvus::scalar::OperatorType::GE);
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
|
||||
auto bitset3 = index->Range(test_max + 1, milvus::scalar::OperatorType::LT);
|
||||
ASSERT_EQ(arr.size(), bitset3->size());
|
||||
ASSERT_TRUE(bitset3->any());
|
||||
|
||||
auto bitset4 = index->Range(test_max, milvus::scalar::OperatorType::LE);
|
||||
ASSERT_EQ(arr.size(), bitset4->size());
|
||||
ASSERT_TRUE(bitset4->any());
|
||||
|
||||
auto bitset5 = index->Range(test_min, true, test_max, true);
|
||||
ASSERT_EQ(arr.size(), bitset5->size());
|
||||
ASSERT_TRUE(bitset5->any());
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
assert_in(const ScalarIndexPtr<std::string>& index, const std::vector<std::string>& arr) {
|
||||
auto bitset1 = index->In(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->any());
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
assert_not_in(const ScalarIndexPtr<std::string>& index, const std::vector<std::string>& arr) {
|
||||
auto bitset1 = index->NotIn(arr.size(), arr.data());
|
||||
ASSERT_EQ(arr.size(), bitset1->size());
|
||||
ASSERT_TRUE(bitset1->none());
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
assert_range(const ScalarIndexPtr<std::string>& index, const std::vector<std::string>& arr) {
|
||||
auto test_min = arr[0];
|
||||
auto test_max = arr[arr.size() - 1];
|
||||
|
||||
auto bitset2 = index->Range(test_min, milvus::scalar::OperatorType::GE);
|
||||
ASSERT_EQ(arr.size(), bitset2->size());
|
||||
ASSERT_TRUE(bitset2->any());
|
||||
|
||||
auto bitset4 = index->Range(test_max, milvus::scalar::OperatorType::LE);
|
||||
ASSERT_EQ(arr.size(), bitset4->size());
|
||||
ASSERT_TRUE(bitset4->any());
|
||||
|
||||
auto bitset5 = index->Range(test_min, true, test_max, true);
|
||||
ASSERT_EQ(arr.size(), bitset5->size());
|
||||
ASSERT_TRUE(bitset5->any());
|
||||
}
|
||||
} // namespace
|
|
@ -20,6 +20,8 @@
|
|||
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||
#include <knowhere/index/vector_index/VecIndexFactory.h>
|
||||
#include <index/ScalarIndex.h>
|
||||
#include <index/StringIndex.h>
|
||||
|
||||
#include "pb/index_cgo_msg.pb.h"
|
||||
|
||||
|
@ -28,7 +30,6 @@
|
|||
#include "DataGen.h"
|
||||
#include "indexbuilder/utils.h"
|
||||
#include "indexbuilder/helper.h"
|
||||
#define private public
|
||||
#include "indexbuilder/ScalarIndexCreator.h"
|
||||
|
||||
constexpr int64_t DIM = 8;
|
||||
|
@ -44,6 +45,8 @@ using knowhere::scalar::OperatorType;
|
|||
using milvus::indexbuilder::MapParams;
|
||||
using milvus::indexbuilder::ScalarIndexCreator;
|
||||
using ScalarTestParams = std::pair<MapParams, MapParams>;
|
||||
using milvus::scalar::ScalarIndexPtr;
|
||||
using milvus::scalar::StringIndexPtr;
|
||||
|
||||
namespace {
|
||||
auto
|
||||
|
@ -397,15 +400,6 @@ GenArr<std::string>(int64_t n) {
|
|||
return GenStrArr(n);
|
||||
}
|
||||
|
||||
template <typename T, typename = typename std::enable_if_t<std::is_arithmetic_v<T>>>
|
||||
inline std::vector<ScalarTestParams>
|
||||
GenParams() {
|
||||
std::vector<ScalarTestParams> ret;
|
||||
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "inverted_index"}}));
|
||||
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "flat"}}));
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<ScalarTestParams>
|
||||
GenBoolParams() {
|
||||
std::vector<ScalarTestParams> ret;
|
||||
|
@ -417,7 +411,24 @@ GenBoolParams() {
|
|||
std::vector<ScalarTestParams>
|
||||
GenStringParams() {
|
||||
std::vector<ScalarTestParams> ret;
|
||||
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "marisa-trie"}}));
|
||||
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "marisa"}}));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, typename = typename std::enable_if_t<std::is_arithmetic_v<T> | std::is_same_v<std::string, T>>>
|
||||
inline std::vector<ScalarTestParams>
|
||||
GenParams() {
|
||||
if (std::is_same_v<std::string, T>) {
|
||||
return GenStringParams();
|
||||
}
|
||||
|
||||
if (std::is_same_v<T, bool>) {
|
||||
return GenBoolParams();
|
||||
}
|
||||
|
||||
std::vector<ScalarTestParams> ret;
|
||||
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "inverted_index"}}));
|
||||
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "flat"}}));
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -438,14 +449,6 @@ PrintMapParams(const std::vector<ScalarTestParams>& tps) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void
|
||||
build_index(const std::unique_ptr<ScalarIndexCreator<T>>& creator, const std::vector<T>& arr) {
|
||||
const int64_t dim = 8; // not important here
|
||||
auto dataset = knowhere::GenDataset(arr.size(), dim, arr.data());
|
||||
creator->Build(dataset);
|
||||
}
|
||||
|
||||
// memory generated by this function should be freed by the caller.
|
||||
auto
|
||||
GenDsFromPB(const google::protobuf::Message& msg) {
|
||||
|
@ -453,4 +456,17 @@ GenDsFromPB(const google::protobuf::Message& msg) {
|
|||
msg.SerializeToArray(data, msg.ByteSize());
|
||||
return knowhere::GenDataset(msg.ByteSize(), 8, data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<std::string>
|
||||
GetIndexTypes() {
|
||||
return std::vector<std::string>{"inverted_index"};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<std::string>
|
||||
GetIndexTypes<std::string>() {
|
||||
return std::vector<std::string>{"marisa"};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
@ -4,9 +4,9 @@ package indexcgowrapper
|
|||
|
||||
#cgo CFLAGS: -I${SRCDIR}/../../core/output/include
|
||||
|
||||
#cgo darwin LDFLAGS: -L${SRCDIR}/../../core/output/lib -lmilvus_indexbuilder -lmilvus_common -Wl,-rpath,"${SRCDIR}/../../core/output/lib"
|
||||
#cgo linux LDFLAGS: -L${SRCDIR}/../../core/output/lib -lmilvus_indexbuilder -lmilvus_common -Wl,-rpath=${SRCDIR}/../../core/output/lib
|
||||
#cgo windows LDFLAGS: -L${SRCDIR}/../../core/output/lib -lmilvus_indexbuilder -lmilvus_common -Wl,-rpath=${SRCDIR}/../../core/output/lib
|
||||
#cgo darwin LDFLAGS: -L${SRCDIR}/../../core/output/lib -lmilvus_common -Wl,-rpath,"${SRCDIR}/../../core/output/lib"
|
||||
#cgo linux LDFLAGS: -L${SRCDIR}/../../core/output/lib -lmilvus_common -Wl,-rpath=${SRCDIR}/../../core/output/lib
|
||||
#cgo windows LDFLAGS: -L${SRCDIR}/../../core/output/lib -lmilvus_common -Wl,-rpath=${SRCDIR}/../../core/output/lib
|
||||
|
||||
#include <stdlib.h> // free
|
||||
#include "indexbuilder/index_c.h"
|
||||
|
|
|
@ -4,9 +4,9 @@ package indexcgowrapper
|
|||
|
||||
#cgo CFLAGS: -I${SRCDIR}/../../core/output/include
|
||||
|
||||
#cgo darwin LDFLAGS: -L${SRCDIR}/../../core/output/lib -lmilvus_indexbuilder -lmilvus_common -Wl,-rpath,"${SRCDIR}/../../core/output/lib"
|
||||
#cgo linux LDFLAGS: -L${SRCDIR}/../../core/output/lib -lmilvus_indexbuilder -lmilvus_common -Wl,-rpath=${SRCDIR}/../../core/output/lib
|
||||
#cgo windows LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_common -lmilvus_indexbuilder -Wl,-rpath=${SRCDIR}/../core/output/lib
|
||||
#cgo darwin LDFLAGS: -L${SRCDIR}/../../core/output/lib -lmilvus_indexbuilder -Wl,-rpath,"${SRCDIR}/../../core/output/lib"
|
||||
#cgo linux LDFLAGS: -L${SRCDIR}/../../core/output/lib -lmilvus_indexbuilder -Wl,-rpath=${SRCDIR}/../../core/output/lib
|
||||
#cgo windows LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_indexbuilder -Wl,-rpath=${SRCDIR}/../core/output/lib
|
||||
|
||||
#include <stdlib.h> // free
|
||||
#include "indexbuilder/index_c.h"
|
||||
|
|
Loading…
Reference in New Issue