enhance: optimize search performance of inverted index (#29794)

issue: #29793 
Use `DocSetCollector` instead of `TopDocsCollector`, which will avoid
scoring and sorting.

---------

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/29871/head
Jiquan Long 2024-01-11 11:12:49 +08:00 committed by GitHub
parent 5164d30287
commit 67ab5be15a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 386 additions and 79 deletions

View File

@ -64,3 +64,10 @@ target_link_libraries(test_tantivy
boost_filesystem boost_filesystem
dl dl
) )
# Benchmark executable for the tantivy binding, built from bench.cpp.
add_executable(bench_tantivy bench.cpp)
# Link the benchmark against the binding library plus boost_filesystem and dl.
target_link_libraries(bench_tantivy
tantivy_binding
boost_filesystem
dl
)

View File

@ -0,0 +1,65 @@
#include <cstdint>
#include <cassert>
#include <boost/filesystem.hpp>
#include <iostream>
#include <random>
#include "tantivy-binding.h"
#include "tantivy-wrapper.h"
#include "time_recorder.h"
using namespace milvus::tantivy;
void
build_index(size_t n = 1000000) {
auto path = "/tmp/inverted-index/test-binding/";
boost::filesystem::remove_all(path);
boost::filesystem::create_directories(path);
auto w =
TantivyIndexWrapper("test_field_name", TantivyDataType::Keyword, path);
std::vector<std::string> arr;
arr.reserve(n);
std::default_random_engine er(42);
int64_t sample = 10000;
for (size_t i = 0; i < n; i++) {
auto x = er() % sample;
arr.push_back(std::to_string(x));
}
w.add_data<std::string>(arr.data(), arr.size());
w.finish();
assert(w.count() == n);
}
// Opens the index written by build_index() and times `repeat` identical
// lower-bound range queries against it.
//
// Each phase (existence check, open + count, every query) is reported through
// a TimeRecorder section, followed by the total elapsed time.
//
// If the index directory does not exist, the function asserts in debug builds
// and otherwise prints an error and returns without running any query.
void
search(size_t repeat = 10) {
    TimeRecorder tr("bench-tantivy-search");
    auto path = "/tmp/inverted-index/test-binding/";

    // The call must live outside assert(): with NDEBUG defined the whole
    // assert expression is removed, which would skip the check entirely and
    // leave the timing section below measuring nothing.
    const bool index_exists = tantivy_index_exist(path);
    assert(index_exists);
    if (!index_exists) {
        std::cerr << "index not found at " << path
                  << ", build it before searching" << std::endl;
        return;
    }
    tr.RecordSection("check if index exist");

    auto w = TantivyIndexWrapper(path);
    auto cnt = w.count();
    tr.RecordSection("count num_entities");
    std::cout << "index already exist, open it, count: " << cnt << std::endl;

    // Run the same query repeatedly; the result is discarded, only the
    // per-iteration time is of interest.
    for (size_t i = 0; i < repeat; i++) {
        w.lower_bound_range_query<std::string>(std::to_string(45), false);
        tr.RecordSection("query");
    }

    tr.ElapseFromBegin("done");
}
// Benchmark entry point: builds a one-million-entry index and then runs the
// timed search ten times. Command-line arguments are not used.
int
main(int argc, char* argv[]) {
    (void)argc;
    (void)argv;

    constexpr size_t kNumEntries = 1000000;
    constexpr size_t kNumQueries = 10;

    build_index(kNumEntries);
    search(kNumQueries);
    return 0;
}

View File

@ -1,6 +1,5 @@
use std::{env, path::PathBuf}; use std::{env, path::PathBuf};
fn main() { fn main() {
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let package_name = env::var("CARGO_PKG_NAME").unwrap(); let package_name = env::var("CARGO_PKG_NAME").unwrap();

View File

@ -21,30 +21,6 @@ struct RustArray {
extern "C" { extern "C" {
void *tantivy_create_index(const char *field_name, TantivyDataType data_type, const char *path);
void tantivy_free_index_writer(void *ptr);
void tantivy_finish_index(void *ptr);
void tantivy_index_add_int8s(void *ptr, const int8_t *array, uintptr_t len);
void tantivy_index_add_int16s(void *ptr, const int16_t *array, uintptr_t len);
void tantivy_index_add_int32s(void *ptr, const int32_t *array, uintptr_t len);
void tantivy_index_add_int64s(void *ptr, const int64_t *array, uintptr_t len);
void tantivy_index_add_f32s(void *ptr, const float *array, uintptr_t len);
void tantivy_index_add_f64s(void *ptr, const double *array, uintptr_t len);
void tantivy_index_add_bools(void *ptr, const bool *array, uintptr_t len);
void tantivy_index_add_keyword(void *ptr, const char *s);
bool tantivy_index_exist(const char *path);
void free_rust_array(RustArray array); void free_rust_array(RustArray array);
void *tantivy_load_index(const char *path); void *tantivy_load_index(const char *path);
@ -97,4 +73,28 @@ RustArray tantivy_range_query_keyword(void *ptr,
RustArray tantivy_prefix_query_keyword(void *ptr, const char *prefix); RustArray tantivy_prefix_query_keyword(void *ptr, const char *prefix);
void *tantivy_create_index(const char *field_name, TantivyDataType data_type, const char *path);
void tantivy_free_index_writer(void *ptr);
void tantivy_finish_index(void *ptr);
void tantivy_index_add_int8s(void *ptr, const int8_t *array, uintptr_t len);
void tantivy_index_add_int16s(void *ptr, const int16_t *array, uintptr_t len);
void tantivy_index_add_int32s(void *ptr, const int32_t *array, uintptr_t len);
void tantivy_index_add_int64s(void *ptr, const int64_t *array, uintptr_t len);
void tantivy_index_add_f32s(void *ptr, const float *array, uintptr_t len);
void tantivy_index_add_f64s(void *ptr, const double *array, uintptr_t len);
void tantivy_index_add_bools(void *ptr, const bool *array, uintptr_t len);
void tantivy_index_add_keyword(void *ptr, const char *s);
bool tantivy_index_exist(const char *path);
} // extern "C" } // extern "C"

View File

@ -1,6 +1,5 @@
use libc::size_t; use libc::size_t;
#[repr(C)] #[repr(C)]
pub struct RustArray { pub struct RustArray {
array: *mut u32, array: *mut u32,

View File

@ -0,0 +1,59 @@
use std::collections::HashSet;
use tantivy::{
collector::{Collector, SegmentCollector},
DocId,
};
/// Collector that gathers the id of every matching document into a [`HashSet`].
///
/// No scoring is requested, so the search only has to enumerate matches.
pub struct HashSetCollector;

impl Collector for HashSetCollector {
    type Fruit = HashSet<DocId>;
    type Child = HashSetChildCollector;

    /// Creates an empty per-segment collector.
    ///
    /// Neither the segment ordinal nor the segment reader is consulted.
    fn for_segment(
        &self,
        _segment_local_id: tantivy::SegmentOrdinal,
        _segment: &tantivy::SegmentReader,
    ) -> tantivy::Result<Self::Child> {
        let docs = HashSet::new();
        Ok(HashSetChildCollector { docs })
    }

    /// Scores are ignored by this collector, so none need to be computed.
    fn requires_scoring(&self) -> bool {
        false
    }

    /// Unions the per-segment sets into one set.
    ///
    /// When exactly one segment reported, its set is returned as-is.
    ///
    /// NOTE(review): the ids are merged without their segment ordinal; this
    /// presumably relies on ids being unique across segments (e.g. a
    /// single-segment index) — confirm against the caller.
    fn merge_fruits(
        &self,
        segment_fruits: Vec<HashSet<DocId>>,
    ) -> tantivy::Result<HashSet<DocId>> {
        let mut fruits = segment_fruits;
        if fruits.len() == 1 {
            return Ok(fruits.pop().unwrap());
        }

        // Size the result for the worst case in which no id repeats.
        let capacity: usize = fruits.iter().map(HashSet::len).sum();
        let mut merged = HashSet::with_capacity(capacity);
        merged.extend(fruits.into_iter().flatten());
        Ok(merged)
    }
}
/// Per-segment collector that accumulates document ids in a [`HashSet`].
pub struct HashSetChildCollector {
    /// Ids collected so far for this segment.
    docs: HashSet<DocId>,
}

impl SegmentCollector for HashSetChildCollector {
    type Fruit = HashSet<DocId>;

    /// Records `doc`; the score argument is ignored.
    fn collect(&mut self, doc: DocId, _score: tantivy::Score) {
        // Whether the id was already present is irrelevant here.
        let _newly_added = self.docs.insert(doc);
    }

    /// Consumes the collector and yields the collected set.
    fn harvest(self) -> Self::Fruit {
        let Self { docs } = self;
        docs
    }
}

View File

@ -1,16 +1,13 @@
use std::ops::Bound; use std::ops::Bound;
use std::str::FromStr; use std::str::FromStr;
use tantivy::collector::TopDocs;
use tantivy::directory::MmapDirectory; use tantivy::directory::MmapDirectory;
use tantivy::query::{Query, RangeQuery, TermQuery, RegexQuery}; use tantivy::query::{Query, RangeQuery, RegexQuery, TermQuery};
use tantivy::schema::{Field, IndexRecordOption}; use tantivy::schema::{Field, IndexRecordOption};
use tantivy::{Index, IndexReader, ReloadPolicy, Term}; use tantivy::{Index, IndexReader, ReloadPolicy, Term};
use crate::util::make_bounds; use crate::util::make_bounds;
use crate::vec_collector::VecCollector;
pub struct IndexReaderWrapper { pub struct IndexReaderWrapper {
pub field_name: String, pub field_name: String,
@ -26,9 +23,7 @@ impl IndexReaderWrapper {
.reload_policy(ReloadPolicy::Manual) .reload_policy(ReloadPolicy::Manual)
.try_into() .try_into()
.unwrap(); .unwrap();
let metas = index let metas = index.searchable_segment_metas().unwrap();
.searchable_segment_metas()
.unwrap();
let mut sum: u32 = 0; let mut sum: u32 = 0;
for meta in metas { for meta in metas {
sum += meta.max_doc(); sum += meta.max_doc();
@ -57,13 +52,10 @@ impl IndexReaderWrapper {
fn search(&self, q: &dyn Query) -> Vec<u32> { fn search(&self, q: &dyn Query) -> Vec<u32> {
let searcher = self.reader.searcher(); let searcher = self.reader.searcher();
let cnt = self.cnt; let hits = searcher.search(q, &VecCollector).unwrap();
let hits = searcher let mut ret = Vec::with_capacity(hits.len());
.search(q, &TopDocs::with_limit(cnt as usize)) for address in hits {
.unwrap(); ret.push(address);
let mut ret = Vec::new();
for (_, address) in hits {
ret.push(address.doc_id);
} }
ret ret
} }
@ -193,10 +185,7 @@ impl IndexReaderWrapper {
self.search(&q) self.search(&q)
} }
pub fn prefix_query_keyword( pub fn prefix_query_keyword(&self, prefix: &str) -> Vec<u32> {
&self,
prefix: &str,
) -> Vec<u32> {
let pattern = format!("{}(.|\n)*", prefix); let pattern = format!("{}(.|\n)*", prefix);
let q = RegexQuery::from_pattern(&pattern, self.field).unwrap(); let q = RegexQuery::from_pattern(&pattern, self.field).unwrap();
self.search(&q) self.search(&q)

View File

@ -1,7 +1,10 @@
use std::ffi::{c_char, c_void, CStr}; use std::ffi::{c_char, c_void, CStr};
use crate::{ use crate::{
index_reader::IndexReaderWrapper, util_c::tantivy_index_exist, util::{create_binding, free_binding}, array::RustArray, array::RustArray,
index_reader::IndexReaderWrapper,
util::{create_binding, free_binding},
util_c::tantivy_index_exist,
}; };
#[no_mangle] #[no_mangle]
@ -196,7 +199,10 @@ pub extern "C" fn tantivy_range_query_keyword(
} }
#[no_mangle] #[no_mangle]
pub extern "C" fn tantivy_prefix_query_keyword(ptr: *mut c_void, prefix: *const c_char) -> RustArray { pub extern "C" fn tantivy_prefix_query_keyword(
ptr: *mut c_void,
prefix: *const c_char,
) -> RustArray {
let real = ptr as *mut IndexReaderWrapper; let real = ptr as *mut IndexReaderWrapper;
unsafe { unsafe {
let c_str = CStr::from_ptr(prefix); let c_str = CStr::from_ptr(prefix);

View File

@ -4,7 +4,6 @@ use tantivy::schema::{Field, IndexRecordOption, Schema, TextFieldIndexing, TextO
use tantivy::{doc, tokenizer, Index, IndexWriter}; use tantivy::{doc, tokenizer, Index, IndexWriter};
use crate::data_type::TantivyDataType; use crate::data_type::TantivyDataType;
use crate::index_reader::IndexReaderWrapper;
pub struct IndexWriterWrapper { pub struct IndexWriterWrapper {
pub field_name: String, pub field_name: String,

View File

@ -2,7 +2,9 @@ use core::slice;
use std::ffi::{c_char, c_void, CStr}; use std::ffi::{c_char, c_void, CStr};
use crate::{ use crate::{
data_type::TantivyDataType, index_writer::IndexWriterWrapper, util::{create_binding, free_binding}, data_type::TantivyDataType,
index_writer::IndexWriterWrapper,
util::{create_binding, free_binding},
}; };
#[no_mangle] #[no_mangle]
@ -31,9 +33,7 @@ pub extern "C" fn tantivy_free_index_writer(ptr: *mut c_void) {
#[no_mangle] #[no_mangle]
pub extern "C" fn tantivy_finish_index(ptr: *mut c_void) { pub extern "C" fn tantivy_finish_index(ptr: *mut c_void) {
let real = ptr as *mut IndexWriterWrapper; let real = ptr as *mut IndexWriterWrapper;
unsafe { unsafe { Box::from_raw(real).finish() }
Box::from_raw(real).finish()
}
} }
// -------------------------build-------------------- // -------------------------build--------------------

View File

@ -1,11 +1,14 @@
mod data_type;
mod index_writer;
mod index_writer_c;
mod util;
mod util_c;
mod array; mod array;
mod data_type;
mod hashset_collector;
mod index_reader; mod index_reader;
mod index_reader_c; mod index_reader_c;
mod index_writer;
mod index_writer_c;
mod linkedlist_collector;
mod util;
mod util_c;
mod vec_collector;
pub fn add(left: usize, right: usize) -> usize { pub fn add(left: usize, right: usize) -> usize {
left + right left + right

View File

@ -0,0 +1,61 @@
use std::collections::LinkedList;
use tantivy::{
collector::{Collector, SegmentCollector},
DocId,
};
/// Collector that gathers the id of every matching document into a
/// [`LinkedList`].
///
/// No scoring is requested, so the search only has to enumerate matches.
pub struct LinkedListCollector;

impl Collector for LinkedListCollector {
    type Fruit = LinkedList<DocId>;
    type Child = LinkedListChildCollector;

    /// Creates an empty per-segment collector.
    ///
    /// Neither the segment ordinal nor the segment reader is consulted.
    fn for_segment(
        &self,
        _segment_local_id: tantivy::SegmentOrdinal,
        _segment: &tantivy::SegmentReader,
    ) -> tantivy::Result<Self::Child> {
        let docs = LinkedList::new();
        Ok(LinkedListChildCollector { docs })
    }

    /// Scores are ignored by this collector, so none need to be computed.
    fn requires_scoring(&self) -> bool {
        false
    }

    /// Combines the per-segment lists into one list.
    ///
    /// When exactly one segment reported, its list is returned as-is.
    /// Otherwise every id is pushed onto the front of a fresh list, walking
    /// the segments in order and each segment's list from front to back.
    ///
    /// NOTE(review): `LinkedList::append` would splice whole lists without
    /// touching individual elements, but it yields a different element order;
    /// keep the per-element push unless callers are known not to care.
    fn merge_fruits(
        &self,
        segment_fruits: Vec<LinkedList<DocId>>,
    ) -> tantivy::Result<LinkedList<DocId>> {
        let mut fruits = segment_fruits;
        if fruits.len() == 1 {
            return Ok(fruits.pop().unwrap());
        }

        let mut merged = LinkedList::new();
        fruits
            .into_iter()
            .flatten()
            .for_each(|doc| merged.push_front(doc));
        Ok(merged)
    }
}
/// Per-segment collector that accumulates document ids in a [`LinkedList`].
pub struct LinkedListChildCollector {
    /// Ids collected so far for this segment, most recently collected first.
    docs: LinkedList<DocId>,
}

impl SegmentCollector for LinkedListChildCollector {
    type Fruit = LinkedList<DocId>;

    /// Records `doc` at the front of the list; the score argument is ignored.
    fn collect(&mut self, doc: DocId, _score: tantivy::Score) {
        let ids = &mut self.docs;
        ids.push_front(doc);
    }

    /// Consumes the collector and yields the collected list.
    fn harvest(self) -> Self::Fruit {
        let Self { docs } = self;
        docs
    }
}

View File

@ -1,5 +1,5 @@
use std::ffi::c_void;
use std::ops::Bound; use std::ops::Bound;
use std::ffi::{c_void};
use tantivy::{directory::MmapDirectory, Index}; use tantivy::{directory::MmapDirectory, Index};

View File

@ -1,6 +1,6 @@
use std::ffi::{c_char, CStr}; use std::ffi::{c_char, CStr};
use crate::{util::index_exist}; use crate::util::index_exist;
#[no_mangle] #[no_mangle]
pub extern "C" fn tantivy_index_exist(path: *const c_char) -> bool { pub extern "C" fn tantivy_index_exist(path: *const c_char) -> bool {

View File

@ -0,0 +1,55 @@
use tantivy::{
collector::{Collector, SegmentCollector},
DocId,
};
/// Collector that gathers the id of every matching document into a `Vec`.
///
/// No scoring is requested, so the search only has to enumerate matches.
pub struct VecCollector;

impl Collector for VecCollector {
    type Fruit = Vec<DocId>;
    type Child = VecChildCollector;

    /// Creates an empty per-segment collector.
    ///
    /// Neither the segment ordinal nor the segment reader is consulted.
    fn for_segment(
        &self,
        _segment_local_id: tantivy::SegmentOrdinal,
        _segment: &tantivy::SegmentReader,
    ) -> tantivy::Result<Self::Child> {
        let docs = Vec::new();
        Ok(VecChildCollector { docs })
    }

    /// Scores are ignored by this collector, so none need to be computed.
    fn requires_scoring(&self) -> bool {
        false
    }

    /// Concatenates the per-segment vectors, in segment order, into one vector.
    ///
    /// When exactly one segment reported, its vector is returned as-is.
    ///
    /// NOTE(review): the ids are concatenated without their segment ordinal;
    /// this presumably relies on ids being unique across segments (e.g. a
    /// single-segment index) — confirm against the caller.
    fn merge_fruits(
        &self,
        segment_fruits: Vec<Vec<DocId>>,
    ) -> tantivy::Result<Vec<DocId>> {
        let mut fruits = segment_fruits;
        if fruits.len() == 1 {
            return Ok(fruits.pop().unwrap());
        }

        // Allocate once for the combined length, then move each chunk over.
        let total: usize = fruits.iter().map(Vec::len).sum();
        let mut merged = Vec::with_capacity(total);
        for mut fruit in fruits {
            merged.append(&mut fruit);
        }
        Ok(merged)
    }
}
/// Per-segment collector that accumulates document ids in a `Vec`.
pub struct VecChildCollector {
    /// Ids collected so far for this segment, in the order they were reported.
    docs: Vec<DocId>,
}

impl SegmentCollector for VecChildCollector {
    type Fruit = Vec<DocId>;

    /// Appends `doc` to the vector; the score argument is ignored.
    fn collect(&mut self, doc: DocId, _score: tantivy::Score) {
        let ids = &mut self.docs;
        ids.push(doc);
    }

    /// Consumes the collector and yields the collected vector.
    fn harvest(self) -> Self::Fruit {
        let Self { docs } = self;
        docs
    }
}

View File

@ -0,0 +1,65 @@
#pragma once
#include <chrono>
#include <iostream>
#include <string>
// Small stopwatch that prints labelled timings to std::cout.
//
// The clock starts at construction. RecordSection() reports the time since
// the previous section (or since construction) and starts a new section;
// ElapseFromBegin() reports the time since construction without affecting the
// current section. Both return the span in microseconds and print it in
// milliseconds, prefixed with the header when it is non-empty.
class TimeRecorder {
    // steady_clock is monotonic, so a span can never come out negative.
    // high_resolution_clock is allowed to alias the system clock, which can
    // be adjusted while a measurement is in flight.
    using stdclock = std::chrono::steady_clock;

 public:
    // trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5
    //
    // `hdr` is printed in front of every record; `log_level` is stored but
    // not consulted by any method of this class.
    explicit TimeRecorder(std::string hdr, int64_t log_level = 0)
        : header_(std::move(hdr)),
          start_(stdclock::now()),
          last_(start_),
          log_level_(log_level) {
    }

    virtual ~TimeRecorder() = default;

    // Prints and returns the microseconds elapsed since the previous call to
    // RecordSection() (or since construction), then starts a new section.
    double
    RecordSection(const std::string& msg) {
        const stdclock::time_point curr = stdclock::now();
        const double span = MicrosBetween(last_, curr);
        last_ = curr;
        PrintTimeRecord(msg, span);
        return span;
    }

    // Prints and returns the microseconds elapsed since construction.
    double
    ElapseFromBegin(const std::string& msg) {
        const double span = MicrosBetween(start_, stdclock::now());
        PrintTimeRecord(msg, span);
        return span;
    }

    // Formats a span given in microseconds as "<milliseconds> ms".
    static std::string
    GetTimeSpanStr(double span) {
        return std::to_string(span * 0.001) + " ms";
    }

 private:
    // Distance between two time points as fractional microseconds.
    static double
    MicrosBetween(stdclock::time_point from, stdclock::time_point to) {
        return std::chrono::duration<double, std::micro>(to - from).count();
    }

    // Writes "<header>: <msg> (<span> ms)" to std::cout; the "<header>: "
    // prefix is omitted when the header is empty.
    void
    PrintTimeRecord(const std::string& msg, double span) const {
        std::string str_log;
        if (!header_.empty()) {
            str_log += header_ + ": ";
        }
        str_log += msg;
        str_log += " (";
        str_log += TimeRecorder::GetTimeSpanStr(span);
        str_log += ")";
        std::cout << str_log << std::endl;
    }

 private:
    std::string header_;
    stdclock::time_point start_;
    stdclock::time_point last_;
    int64_t log_level_;
};