mirror of https://github.com/milvus-io/milvus.git
Fix bug: override the compare function of SearchResultPair (#6628)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/6623/head^2
parent
ab5a7cbf44
commit
99249a0224
|
@ -0,0 +1,40 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <cmath> // std::isnan
|
||||
#include <common/Types.h>
|
||||
#include "segcore/Reduce.h"
|
||||
|
||||
struct SearchResultPair {
|
||||
float distance_;
|
||||
milvus::SearchResult* search_result_;
|
||||
int64_t offset_;
|
||||
int64_t index_;
|
||||
|
||||
SearchResultPair(float distance, milvus::SearchResult* search_result, int64_t offset, int64_t index)
|
||||
: distance_(distance), search_result_(search_result), offset_(offset), index_(index) {
|
||||
}
|
||||
|
||||
bool
|
||||
operator<(const SearchResultPair& pair) const {
|
||||
return std::isnan(pair.distance_) || (!std::isnan(distance_) && (distance_ < pair.distance_));
|
||||
}
|
||||
|
||||
bool
|
||||
operator>(const SearchResultPair& pair) const {
|
||||
return std::isnan(pair.distance_) || (!std::isnan(distance_) && (distance_ > pair.distance_));
|
||||
}
|
||||
|
||||
void
|
||||
reset_distance() {
|
||||
distance_ = search_result_->result_distances_[offset_];
|
||||
}
|
||||
};
|
|
@ -14,6 +14,7 @@
|
|||
#include "segcore/reduce_c.h"
|
||||
|
||||
#include "segcore/Reduce.h"
|
||||
#include "segcore/ReduceStructure.h"
|
||||
#include "common/Types.h"
|
||||
#include "pb/milvus.pb.h"
|
||||
|
||||
|
@ -49,32 +50,6 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
|
|||
delete hits;
|
||||
}
|
||||
|
||||
struct SearchResultPair {
|
||||
float distance_;
|
||||
SearchResult* search_result_;
|
||||
int64_t offset_;
|
||||
int64_t index_;
|
||||
|
||||
SearchResultPair(float distance, SearchResult* search_result, int64_t offset, int64_t index)
|
||||
: distance_(distance), search_result_(search_result), offset_(offset), index_(index) {
|
||||
}
|
||||
|
||||
bool
|
||||
operator<(const SearchResultPair& pair) const {
|
||||
return (distance_ < pair.distance_);
|
||||
}
|
||||
|
||||
bool
|
||||
operator>(const SearchResultPair& pair) const {
|
||||
return (distance_ > pair.distance_);
|
||||
}
|
||||
|
||||
void
|
||||
reset_distance() {
|
||||
distance_ = search_result_->result_distances_[offset_];
|
||||
}
|
||||
};
|
||||
|
||||
void
|
||||
GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
||||
std::vector<SearchResult*>& search_results,
|
||||
|
|
|
@ -37,6 +37,7 @@ set(MILVUS_TEST_FILES
|
|||
test_segcore.cpp
|
||||
test_span.cpp
|
||||
test_timestamp_index.cpp
|
||||
test_reduce_c.cpp
|
||||
)
|
||||
|
||||
add_executable(all_tests
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "segcore/ReduceStructure.h"
|
||||
|
||||
TEST(SearchResultPair, Less) {
|
||||
auto pair1 = SearchResultPair(1.0, nullptr, 0, 0);
|
||||
auto pair2 = SearchResultPair(1.0, nullptr, 0, 0);
|
||||
ASSERT_EQ(pair1 < pair2, false);
|
||||
ASSERT_EQ(pair1.operator<(pair2), false);
|
||||
|
||||
pair1.distance_ = 1.0;
|
||||
pair2.distance_ = 2.0;
|
||||
ASSERT_EQ(pair1 < pair2, true);
|
||||
ASSERT_EQ(pair1.operator<(pair2), true);
|
||||
|
||||
pair1.distance_ = 1.0;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 < pair2, true);
|
||||
ASSERT_EQ(pair1.operator<(pair2), true);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = 1.0;
|
||||
ASSERT_EQ(pair1 < pair2, false);
|
||||
ASSERT_EQ(pair1.operator<(pair2), false);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = 2.0;
|
||||
ASSERT_EQ(pair1 < pair2, false);
|
||||
ASSERT_EQ(pair1.operator<(pair2), false);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 < pair2, true);
|
||||
ASSERT_EQ(pair1.operator<(pair2), true);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = 1.0;
|
||||
ASSERT_EQ(pair1 < pair2, false);
|
||||
ASSERT_EQ(pair1.operator<(pair2), false);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = 2.0;
|
||||
ASSERT_EQ(pair1 < pair2, false);
|
||||
ASSERT_EQ(pair1.operator<(pair2), false);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 < pair2, true);
|
||||
ASSERT_EQ(pair1.operator<(pair2), true);
|
||||
}
|
||||
|
||||
TEST(SearchResultPair, Greater) {
|
||||
auto pair1 = SearchResultPair(1.0, nullptr, 0, 0);
|
||||
auto pair2 = SearchResultPair(1.0, nullptr, 0, 0);
|
||||
ASSERT_EQ(pair1 > pair2, false);
|
||||
ASSERT_EQ(pair1.operator>(pair2), false);
|
||||
|
||||
pair1.distance_ = 1.0;
|
||||
pair2.distance_ = 2.0;
|
||||
ASSERT_EQ(pair1 > pair2, false);
|
||||
ASSERT_EQ(pair1.operator>(pair2), false);
|
||||
|
||||
pair1.distance_ = 1.0;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 > pair2, true);
|
||||
ASSERT_EQ(pair1.operator>(pair2), true);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = 1.0;
|
||||
ASSERT_EQ(pair1 > pair2, true);
|
||||
ASSERT_EQ(pair1.operator>(pair2), true);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = 2.0;
|
||||
ASSERT_EQ(pair1 > pair2, false);
|
||||
ASSERT_EQ(pair1.operator>(pair2), false);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 > pair2, true);
|
||||
ASSERT_EQ(pair1.operator>(pair2), true);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = 1.0;
|
||||
ASSERT_EQ(pair1 > pair2, false);
|
||||
ASSERT_EQ(pair1.operator>(pair2), false);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = 2.0;
|
||||
ASSERT_EQ(pair1 > pair2, false);
|
||||
ASSERT_EQ(pair1.operator>(pair2), false);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 > pair2, true);
|
||||
ASSERT_EQ(pair1.operator>(pair2), true);
|
||||
}
|
Loading…
Reference in New Issue