mirror of https://github.com/milvus-io/milvus.git
Fix wrong distances caused by metric type (#11901)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/11908/head
parent
a4f1c2986a
commit
8c951217ee
|
@ -0,0 +1,24 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "index/thirdparty/faiss/MetricType.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace segcore {
|
||||
static inline bool
|
||||
PositivelyRelated(faiss::MetricType metric_type) {
|
||||
return metric_type == faiss::MetricType::METRIC_INNER_PRODUCT || metric_type == faiss::MetricType::METRIC_Jaccard ||
|
||||
metric_type == faiss::MetricType::METRIC_Tanimoto;
|
||||
}
|
||||
} // namespace segcore
|
||||
} // namespace milvus
|
|
@ -19,6 +19,7 @@
|
|||
#include "segcore/SegmentGrowing.h"
|
||||
#include "segcore/SegmentSealed.h"
|
||||
#include "segcore/segment_c.h"
|
||||
#include "segcore/SimilarityCorelation.h"
|
||||
|
||||
////////////////////////////// common interfaces //////////////////////////////
|
||||
CSegmentInterface
|
||||
|
@ -67,7 +68,8 @@ Search(CSegmentInterface c_segment,
|
|||
auto plan = (milvus::query::Plan*)c_plan;
|
||||
auto phg_ptr = reinterpret_cast<const milvus::query::PlaceholderGroup*>(c_placeholder_group);
|
||||
*search_result = segment->Search(plan, *phg_ptr, timestamp);
|
||||
if (plan->plan_node_->search_info_.metric_type_ != milvus::MetricType::METRIC_INNER_PRODUCT) {
|
||||
// if (plan->plan_node_->search_info_.metric_type_ != milvus::MetricType::METRIC_INNER_PRODUCT) {
|
||||
if (!milvus::segcore::PositivelyRelated(plan->plan_node_->search_info_.metric_type_)) {
|
||||
for (auto& dis : search_result->result_distances_) {
|
||||
dis *= -1;
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ set(MILVUS_TEST_FILES
|
|||
test_timestamp_index.cpp
|
||||
test_reduce_c.cpp
|
||||
test_conf_adapter_mgr.cpp
|
||||
test_similarity_corelation.cpp
|
||||
)
|
||||
|
||||
add_executable(all_tests
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "segcore/SimilarityCorelation.h"
|
||||
|
||||
TEST(SimilarityCorelation, Naive) {
|
||||
ASSERT_TRUE(milvus::segcore::PositivelyRelated(faiss::METRIC_INNER_PRODUCT));
|
||||
ASSERT_TRUE(milvus::segcore::PositivelyRelated(faiss::METRIC_Jaccard));
|
||||
ASSERT_TRUE(milvus::segcore::PositivelyRelated(faiss::METRIC_Tanimoto));
|
||||
|
||||
ASSERT_FALSE(milvus::segcore::PositivelyRelated(faiss::METRIC_L2));
|
||||
ASSERT_FALSE(milvus::segcore::PositivelyRelated(faiss::METRIC_Hamming));
|
||||
ASSERT_FALSE(milvus::segcore::PositivelyRelated(faiss::METRIC_Substructure));
|
||||
ASSERT_FALSE(milvus::segcore::PositivelyRelated(faiss::METRIC_Superstructure));
|
||||
}
|
|
@ -30,6 +30,8 @@ import (
|
|||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/distance"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
@ -1831,7 +1833,8 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
|
|||
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
ret.Results.TopK = realTopK
|
||||
|
||||
if metricType != "IP" {
|
||||
// if metricType != "IP" {
|
||||
if !distance.PositivelyRelated(metricType) {
|
||||
for k := range ret.Results.Scores {
|
||||
ret.Results.Scores[k] *= -1
|
||||
}
|
||||
|
|
|
@ -26,6 +26,12 @@ const (
|
|||
HAMMING = "HAMMING"
|
||||
// TANIMOTO represents the tanimoto distance
|
||||
TANIMOTO = "TANIMOTO"
|
||||
// JACCARD
|
||||
JACCARD = "JACCARD"
|
||||
// SUPERSTRUCTURE
|
||||
SUPERSTRUCTURE = "SUPERSTRUCTURE"
|
||||
// SUBSTRUCTURE
|
||||
SUBSTRUCTURE = "SUBSTRUCTURE"
|
||||
)
|
||||
|
||||
// ValidateMetricType returns metric text or error
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package distance
|
||||
|
||||
import "strings"
|
||||
|
||||
func PositivelyRelated(metricType string) bool {
|
||||
mUpper := strings.ToUpper(metricType)
|
||||
return mUpper == strings.ToUpper(IP) ||
|
||||
mUpper == strings.ToUpper(JACCARD) ||
|
||||
mUpper == strings.ToUpper(TANIMOTO)
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package distance
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPositivelyRelated(t *testing.T) {
|
||||
cases := []struct {
|
||||
metricType string
|
||||
wanted bool
|
||||
}{
|
||||
{
|
||||
IP,
|
||||
true,
|
||||
},
|
||||
{
|
||||
JACCARD,
|
||||
true,
|
||||
},
|
||||
{
|
||||
TANIMOTO,
|
||||
true,
|
||||
},
|
||||
{
|
||||
L2,
|
||||
false,
|
||||
},
|
||||
{
|
||||
HAMMING,
|
||||
false,
|
||||
},
|
||||
{
|
||||
SUPERSTRUCTURE,
|
||||
false,
|
||||
},
|
||||
{
|
||||
SUBSTRUCTURE,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for idx := range cases {
|
||||
if got := PositivelyRelated(cases[idx].metricType); got != cases[idx].wanted {
|
||||
t.Errorf("PositivelyRelated(%v) = %v", cases[idx].metricType, cases[idx].wanted)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue