mirror of https://github.com/milvus-io/milvus.git
Support span in SegmentGrowing, refine vector_trait
Signed-off-by: FluorineDog <guilin.gou@zilliz.com>pull/4973/head^2
parent
04e2062750
commit
7d81222550
|
@ -86,14 +86,8 @@ class Span<T, typename std::enable_if_t<std::is_fundamental_v<T>>> {
|
|||
const int64_t row_count_;
|
||||
};
|
||||
|
||||
namespace segcore {
|
||||
class VectorTrait;
|
||||
class FloatVector;
|
||||
class BinaryVector;
|
||||
} // namespace segcore
|
||||
|
||||
template <typename VectorType>
|
||||
class Span<VectorType, typename std::enable_if_t<std::is_base_of_v<segcore::VectorTrait, VectorType>>> {
|
||||
class Span<VectorType, typename std::enable_if_t<std::is_base_of_v<VectorTrait, VectorType>>> {
|
||||
public:
|
||||
using embedded_type = typename VectorType::embedded_type;
|
||||
|
||||
|
|
|
@ -76,3 +76,5 @@ using FieldName = fluent::NamedType<std::string, struct FieldNameTag, fluent::Co
|
|||
using FieldOffset = fluent::NamedType<int64_t, struct FieldOffsetTag, fluent::Comparable, fluent::Hashable>;
|
||||
|
||||
} // namespace milvus
|
||||
|
||||
#include "VectorTrait.h"
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
#include "Types.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
class VectorTrait {};
|
||||
|
||||
class FloatVector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = float;
|
||||
static constexpr auto metric_type = DataType::VECTOR_FLOAT;
|
||||
};
|
||||
|
||||
class BinaryVector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = uint8_t;
|
||||
static constexpr auto metric_type = DataType::VECTOR_BINARY;
|
||||
};
|
||||
|
||||
template <typename VectorType>
|
||||
inline constexpr int64_t
|
||||
get_element_sizeof(int64_t dim) {
|
||||
static_assert(std::is_base_of_v<VectorType, VectorTrait>);
|
||||
if constexpr (std::is_same_v<VectorType, FloatVector>) {
|
||||
return dim * sizeof(float);
|
||||
} else {
|
||||
return dim / 8;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr bool IsVector = std::is_base_of_v<VectorTrait, T>;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool IsScalar = std::is_fundamental_v<T>;
|
||||
|
||||
template <typename T, typename Enabled = void>
|
||||
struct EmbeddedTypeImpl;
|
||||
|
||||
template <typename T>
|
||||
struct EmbeddedTypeImpl<T, std::enable_if_t<IsScalar<T>>> {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct EmbeddedTypeImpl<T, std::enable_if_t<IsVector<T>>> {
|
||||
using type = std::conditional_t<std::is_same_v<T, FloatVector>, float, uint8_t>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using EmbeddedType = typename EmbeddedTypeImpl<T>::type;
|
||||
|
||||
} // namespace milvus
|
|
@ -92,7 +92,6 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
|
|||
|
||||
final_qr.merge(sub_qr);
|
||||
}
|
||||
using segcore::FloatVector;
|
||||
auto vec_ptr = record.get_entity<FloatVector>(vecfield_offset);
|
||||
|
||||
// step 4: brute force search where small indexing is unavailable
|
||||
|
@ -165,7 +164,6 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment,
|
|||
// TODO: use QuerySubResult instead
|
||||
query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, dim, query_data};
|
||||
|
||||
using segcore::BinaryVector;
|
||||
auto vec_ptr = record.get_entity<BinaryVector>(vecfield_offset);
|
||||
|
||||
auto max_indexed_id = 0;
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "utils/tools.h"
|
||||
#include <boost/container/vector.hpp>
|
||||
#include "common/Types.h"
|
||||
#include "common/Span.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
|
@ -82,6 +83,9 @@ class VectorBase {
|
|||
virtual void
|
||||
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0;
|
||||
|
||||
virtual SpanBase
|
||||
get_span_base(int64_t chunk_id) const = 0;
|
||||
|
||||
int64_t
|
||||
get_chunk_size() const {
|
||||
return chunk_size_;
|
||||
|
@ -104,6 +108,9 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
ConcurrentVectorImpl&
|
||||
operator=(const ConcurrentVectorImpl&) = delete;
|
||||
|
||||
using TraitType =
|
||||
std::conditional_t<is_scalar, Type, std::conditional_t<std::is_same_v<Type, float>, FloatVector, BinaryVector>>;
|
||||
|
||||
public:
|
||||
explicit ConcurrentVectorImpl(ssize_t dim, int64_t chunk_size) : VectorBase(chunk_size), Dim(is_scalar ? 1 : dim) {
|
||||
Assert(is_scalar ? dim == 1 : dim != 1);
|
||||
|
@ -115,6 +122,25 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
chunks_.emplace_to_at_least(chunk_count, Dim * chunk_size_);
|
||||
}
|
||||
|
||||
Span<TraitType>
|
||||
get_span(int64_t chunk_id) const {
|
||||
auto& chunk = get_chunk(chunk_id);
|
||||
if constexpr (is_scalar) {
|
||||
return Span<TraitType>(chunk.data(), chunk_size_);
|
||||
} else if constexpr (std::is_same_v<Type, int64_t> || std::is_same_v<Type, int>) {
|
||||
// only for testing
|
||||
PanicInfo("unimplemented");
|
||||
} else {
|
||||
static_assert(std::is_same_v<typename TraitType::embedded_type, Type>);
|
||||
return Span<TraitType>(chunk.data(), chunk_size_, Dim);
|
||||
}
|
||||
}
|
||||
|
||||
SpanBase
|
||||
get_span_base(int64_t chunk_id) const override {
|
||||
return get_span(chunk_id);
|
||||
}
|
||||
|
||||
void
|
||||
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) override {
|
||||
set_data(element_offset, static_cast<const Type*>(source), element_count);
|
||||
|
@ -206,25 +232,12 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
template <typename Type>
|
||||
class ConcurrentVector : public ConcurrentVectorImpl<Type, true> {
|
||||
public:
|
||||
static_assert(std::is_fundamental_v<Type>);
|
||||
explicit ConcurrentVector(int64_t chunk_size)
|
||||
: ConcurrentVectorImpl<Type, true>::ConcurrentVectorImpl(1, chunk_size) {
|
||||
}
|
||||
};
|
||||
|
||||
class VectorTrait {};
|
||||
|
||||
class FloatVector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = float;
|
||||
static constexpr auto metric_type = DataType::VECTOR_FLOAT;
|
||||
};
|
||||
|
||||
class BinaryVector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = uint8_t;
|
||||
static constexpr auto metric_type = DataType::VECTOR_BINARY;
|
||||
};
|
||||
|
||||
template <>
|
||||
class ConcurrentVector<FloatVector> : public ConcurrentVectorImpl<float, false> {
|
||||
public:
|
||||
|
|
|
@ -76,7 +76,7 @@ IndexingRecord::UpdateResourceAck(int64_t chunk_ack, const InsertRecord& record)
|
|||
// std::thread([this, old_ack, chunk_ack, &record] {
|
||||
for (auto& [field_offset, entry] : entries_) {
|
||||
auto vec_base = record.get_base_entity(field_offset);
|
||||
entry->BuildIndexRange(old_ack, chunk_ack, vec_base.get());
|
||||
entry->BuildIndexRange(old_ack, chunk_ack, vec_base);
|
||||
}
|
||||
finished_ack_.AddSegment(old_ack, chunk_ack);
|
||||
// }).detach();
|
||||
|
|
|
@ -28,7 +28,7 @@ struct InsertRecord {
|
|||
|
||||
auto
|
||||
get_base_entity(FieldOffset field_offset) const {
|
||||
auto ptr = entity_vec_[field_offset.get()];
|
||||
auto ptr = entity_vec_[field_offset.get()].get();
|
||||
return ptr;
|
||||
}
|
||||
|
||||
|
@ -36,7 +36,7 @@ struct InsertRecord {
|
|||
auto
|
||||
get_entity(FieldOffset field_offset) const {
|
||||
auto base_ptr = get_base_entity(field_offset);
|
||||
auto ptr = std::dynamic_pointer_cast<const ConcurrentVector<Type>>(base_ptr);
|
||||
auto ptr = dynamic_cast<const ConcurrentVector<Type>*>(base_ptr);
|
||||
Assert(ptr);
|
||||
return ptr;
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ struct InsertRecord {
|
|||
auto
|
||||
get_entity(FieldOffset field_offset) {
|
||||
auto base_ptr = get_base_entity(field_offset);
|
||||
auto ptr = std::dynamic_pointer_cast<ConcurrentVector<Type>>(base_ptr);
|
||||
auto ptr = dynamic_cast<ConcurrentVector<Type>*>(base_ptr);
|
||||
Assert(ptr);
|
||||
return ptr;
|
||||
}
|
||||
|
@ -54,17 +54,17 @@ struct InsertRecord {
|
|||
void
|
||||
insert_entity(int64_t chunk_size) {
|
||||
static_assert(std::is_fundamental_v<Type>);
|
||||
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<Type>>(chunk_size));
|
||||
entity_vec_.emplace_back(std::make_unique<ConcurrentVector<Type>>(chunk_size));
|
||||
}
|
||||
|
||||
template <typename VectorType>
|
||||
void
|
||||
insert_entity(int64_t dim, int64_t chunk_size) {
|
||||
static_assert(std::is_base_of_v<VectorTrait, VectorType>);
|
||||
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<VectorType>>(dim, chunk_size));
|
||||
entity_vec_.emplace_back(std::make_unique<ConcurrentVector<VectorType>>(dim, chunk_size));
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<VectorBase>> entity_vec_;
|
||||
std::vector<std::unique_ptr<VectorBase>> entity_vec_;
|
||||
};
|
||||
} // namespace milvus::segcore
|
||||
|
|
|
@ -299,4 +299,16 @@ SegmentGrowingImpl::LoadIndexing(const LoadIndexInfo& info) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
SpanBase
|
||||
SegmentGrowingImpl::chunk_data_impl(FieldOffset field_offset, int64_t chunk_id) const {
|
||||
auto vec = get_insert_record().get_base_entity(field_offset);
|
||||
return vec->get_span_base(chunk_id);
|
||||
}
|
||||
|
||||
int64_t
|
||||
SegmentGrowingImpl::get_safe_num_chunk() const {
|
||||
auto size = get_insert_record().ack_responder_.GetAck();
|
||||
return upper_div(size, chunk_size_);
|
||||
}
|
||||
|
||||
} // namespace milvus::segcore
|
||||
|
|
|
@ -112,9 +112,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
}
|
||||
|
||||
int64_t
|
||||
get_num_chunk() const override {
|
||||
PanicInfo("unimplemented");
|
||||
}
|
||||
get_safe_num_chunk() const override;
|
||||
|
||||
Status
|
||||
LoadIndexing(const LoadIndexInfo& info) override;
|
||||
|
@ -139,9 +137,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
|
||||
protected:
|
||||
SpanBase
|
||||
chunk_data_impl(FieldOffset field_offset, int64_t chunk_id) const override {
|
||||
PanicInfo("unimplemented");
|
||||
}
|
||||
chunk_data_impl(FieldOffset field_offset, int64_t chunk_id) const override;
|
||||
|
||||
private:
|
||||
int64_t chunk_size_;
|
||||
|
|
|
@ -44,16 +44,16 @@ class SegmentInternalInterface : public SegmentInterface {
|
|||
get_schema() const = 0;
|
||||
|
||||
virtual int64_t
|
||||
get_num_chunk() const = 0;
|
||||
get_safe_num_chunk() const = 0;
|
||||
|
||||
template <typename T>
|
||||
Span<T>
|
||||
chunk_data(FieldOffset field_offset, int64_t chunk_id) const {
|
||||
auto span = chunk_data_impl(field_offset, chunk_id);
|
||||
return static_cast<Span<T>>(span);
|
||||
return static_cast<Span<T>>(chunk_data_impl(field_offset, chunk_id));
|
||||
}
|
||||
|
||||
protected:
|
||||
// blob and row_count
|
||||
virtual SpanBase
|
||||
chunk_data_impl(FieldOffset field_offset, int64_t chunk_id) const = 0;
|
||||
};
|
||||
|
|
|
@ -17,6 +17,7 @@ set(MILVUS_TEST_FILES
|
|||
test_sealed.cpp
|
||||
test_reduce.cpp
|
||||
test_interface.cpp
|
||||
test_span.cpp
|
||||
)
|
||||
add_executable(all_tests
|
||||
${MILVUS_TEST_FILES}
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "utils/tools.h"
|
||||
#include "test_utils/DataGen.h"
|
||||
#include "segcore/SegmentGrowing.h"
|
||||
|
||||
TEST(Span, Naive) {
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
int64_t N = 1000 * 1000;
|
||||
constexpr int64_t chunk_size = 32 * 1024;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddDebugField("fakevec", DataType::VECTOR_BINARY, 512, MetricType::METRIC_Jaccard);
|
||||
schema->AddDebugField("age", DataType::FLOAT);
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateGrowingSegment(schema, chunk_size);
|
||||
segment->PreInsert(N);
|
||||
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
|
||||
auto vec_ptr = dataset.get_col<uint8_t>(0);
|
||||
auto age_ptr = dataset.get_col<float>(1);
|
||||
SegmentInternalInterface& interface = *segment;
|
||||
auto num_chunk = interface.get_safe_num_chunk();
|
||||
ASSERT_EQ(num_chunk, upper_div(N, chunk_size));
|
||||
auto row_count = interface.get_row_count();
|
||||
ASSERT_EQ(N, row_count);
|
||||
for (auto chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
|
||||
auto vec_span = interface.chunk_data<BinaryVector>(FieldOffset(0), chunk_id);
|
||||
auto age_span = interface.chunk_data<float>(FieldOffset(1), chunk_id);
|
||||
auto begin = chunk_id * chunk_size;
|
||||
auto end = std::min((chunk_id + 1) * chunk_size, N);
|
||||
auto chunk_size = end - begin;
|
||||
for (int i = 0; i < chunk_size * 512 / 8; ++i) {
|
||||
ASSERT_EQ(vec_span.data()[i], vec_ptr[i + begin * 512 / 8]);
|
||||
}
|
||||
for (int i = 0; i < chunk_size; ++i) {
|
||||
ASSERT_EQ(age_span.data()[i], age_ptr[i + begin]);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue