#include #include #include "tantivy-binding.h" namespace milvus::tantivy { struct RustArrayWrapper { explicit RustArrayWrapper(RustArray array) : array_(array) { } RustArrayWrapper(RustArrayWrapper&) = delete; RustArrayWrapper& operator=(RustArrayWrapper&) = delete; RustArrayWrapper(RustArrayWrapper&& other) noexcept { array_.array = other.array_.array; array_.len = other.array_.len; array_.cap = other.array_.cap; other.array_.array = nullptr; other.array_.len = 0; other.array_.cap = 0; } RustArrayWrapper& operator=(RustArrayWrapper&& other) noexcept { if (this != &other) { free(); array_.array = other.array_.array; array_.len = other.array_.len; array_.cap = other.array_.cap; other.array_.array = nullptr; other.array_.len = 0; other.array_.cap = 0; } return *this; } ~RustArrayWrapper() { free(); } void debug() { std::stringstream ss; ss << "[ "; for (int i = 0; i < array_.len; i++) { ss << array_.array[i] << " "; } ss << "]"; std::cout << ss.str() << std::endl; } RustArray array_; private: void free() { if (array_.array != nullptr) { free_rust_array(array_); } } }; template inline TantivyDataType guess_data_type() { if constexpr (std::is_same_v) { return TantivyDataType::Bool; } if constexpr (std::is_integral_v) { return TantivyDataType::I64; } if constexpr (std::is_floating_point_v) { return TantivyDataType::F64; } throw fmt::format("guess_data_type: unsupported data type: {}", typeid(T).name()); } struct TantivyIndexWrapper { using IndexWriter = void*; using IndexReader = void*; TantivyIndexWrapper() = default; TantivyIndexWrapper(TantivyIndexWrapper&) = delete; TantivyIndexWrapper& operator=(TantivyIndexWrapper&) = delete; TantivyIndexWrapper(TantivyIndexWrapper&& other) noexcept { writer_ = other.writer_; reader_ = other.reader_; finished_ = other.finished_; path_ = other.path_; other.writer_ = nullptr; other.reader_ = nullptr; other.finished_ = false; other.path_ = ""; } TantivyIndexWrapper& operator=(TantivyIndexWrapper&& other) noexcept { if (this != &other) { free(); writer_ = other.writer_; reader_ = other.reader_; path_ = other.path_; finished_ = other.finished_; other.writer_ = nullptr; other.reader_ = nullptr; other.finished_ = false; other.path_ = ""; } return *this; } TantivyIndexWrapper(const char* field_name, TantivyDataType data_type, const char* path) { writer_ = tantivy_create_index(field_name, data_type, path); path_ = std::string(path); } explicit TantivyIndexWrapper(const char* path) { assert(tantivy_index_exist(path)); reader_ = tantivy_load_index(path); path_ = std::string(path); } ~TantivyIndexWrapper() { free(); } template void add_data(const T* array, uintptr_t len) { assert(!finished_); if constexpr (std::is_same_v) { tantivy_index_add_bools(writer_, array, len); return; } if constexpr (std::is_same_v) { tantivy_index_add_int8s(writer_, array, len); return; } if constexpr (std::is_same_v) { tantivy_index_add_int16s(writer_, array, len); return; } if constexpr (std::is_same_v) { tantivy_index_add_int32s(writer_, array, len); return; } if constexpr (std::is_same_v) { tantivy_index_add_int64s(writer_, array, len); return; } if constexpr (std::is_same_v) { tantivy_index_add_f32s(writer_, array, len); return; } if constexpr (std::is_same_v) { tantivy_index_add_f64s(writer_, array, len); return; } if constexpr (std::is_same_v) { // TODO: not very efficient, a lot of overhead due to rust-ffi call. for (uintptr_t i = 0; i < len; i++) { tantivy_index_add_keyword( writer_, static_cast(array)[i].c_str()); } return; } throw fmt::format("InvertedIndex.add_data: unsupported data type: {}", typeid(T).name()); } inline void finish() { if (!finished_) { tantivy_finish_index(writer_); writer_ = nullptr; reader_ = tantivy_load_index(path_.c_str()); finished_ = true; } } inline uint32_t count() { return tantivy_index_count(reader_); } public: template RustArrayWrapper term_query(T term) { auto array = [&]() { if constexpr (std::is_same_v) { return tantivy_term_query_bool(reader_, term); } if constexpr (std::is_integral_v) { return tantivy_term_query_i64(reader_, static_cast(term)); } if constexpr (std::is_floating_point_v) { return tantivy_term_query_f64(reader_, static_cast(term)); } if constexpr (std::is_same_v) { return tantivy_term_query_keyword( reader_, static_cast(term).c_str()); } throw fmt::format( "InvertedIndex.term_query: unsupported data type: {}", typeid(T).name()); }(); return RustArrayWrapper(array); } template RustArrayWrapper lower_bound_range_query(T lower_bound, bool inclusive) { auto array = [&]() { if constexpr (std::is_integral_v) { return tantivy_lower_bound_range_query_i64( reader_, static_cast(lower_bound), inclusive); } if constexpr (std::is_floating_point_v) { return tantivy_lower_bound_range_query_f64( reader_, static_cast(lower_bound), inclusive); } if constexpr (std::is_same_v) { return tantivy_lower_bound_range_query_keyword( reader_, static_cast(lower_bound).c_str(), inclusive); } throw fmt::format( "InvertedIndex.lower_bound_range_query: unsupported data type: " "{}", typeid(T).name()); }(); return RustArrayWrapper(array); } template RustArrayWrapper upper_bound_range_query(T upper_bound, bool inclusive) { auto array = [&]() { if constexpr (std::is_integral_v) { return tantivy_upper_bound_range_query_i64( reader_, static_cast(upper_bound), inclusive); } if constexpr (std::is_floating_point_v) { return tantivy_upper_bound_range_query_f64( reader_, static_cast(upper_bound), inclusive); } if constexpr (std::is_same_v) { return tantivy_upper_bound_range_query_keyword( reader_, static_cast(upper_bound).c_str(), inclusive); } throw fmt::format( "InvertedIndex.upper_bound_range_query: unsupported data type: " "{}", typeid(T).name()); }(); return RustArrayWrapper(array); } template RustArrayWrapper range_query(T lower_bound, T upper_bound, bool lb_inclusive, bool ub_inclusive) { auto array = [&]() { if constexpr (std::is_integral_v) { return tantivy_range_query_i64( reader_, static_cast(lower_bound), static_cast(upper_bound), lb_inclusive, ub_inclusive); } if constexpr (std::is_floating_point_v) { return tantivy_range_query_f64(reader_, static_cast(lower_bound), static_cast(upper_bound), lb_inclusive, ub_inclusive); } if constexpr (std::is_same_v) { return tantivy_range_query_keyword( reader_, static_cast(lower_bound).c_str(), static_cast(upper_bound).c_str(), lb_inclusive, ub_inclusive); } throw fmt::format( "InvertedIndex.range_query: unsupported data type: {}", typeid(T).name()); }(); return RustArrayWrapper(array); } RustArrayWrapper prefix_query(const std::string& prefix) { auto array = tantivy_prefix_query_keyword(reader_, prefix.c_str()); return RustArrayWrapper(array); } RustArrayWrapper regex_query(const std::string& pattern) { auto array = tantivy_regex_query(reader_, pattern.c_str()); return RustArrayWrapper(array); } public: inline IndexWriter get_writer() { return writer_; } inline IndexReader get_reader() { return reader_; } private: void check_search() { // TODO } void free() { if (writer_ != nullptr) { tantivy_free_index_writer(writer_); } if (reader_ != nullptr) { tantivy_free_index_reader(reader_); } } private: bool finished_ = false; IndexWriter writer_ = nullptr; IndexReader reader_ = nullptr; std::string path_; }; } // namespace milvus::tantivy