Fix wrong distances caused by metric type (#11901)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/11908/head
dragondriver 2021-11-16 19:11:10 +08:00 committed by GitHub
parent a4f1c2986a
commit 8c951217ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 140 additions and 2 deletions

View File

@ -0,0 +1,24 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "index/thirdparty/faiss/MetricType.h"
namespace milvus {
namespace segcore {

// Reports whether `metric_type` belongs to the set of metrics this helper
// classifies as "positively related": METRIC_INNER_PRODUCT, METRIC_Jaccard
// and METRIC_Tanimoto. Any other metric type yields false.
//
// NOTE(review): presumably "positively related" means a larger raw value
// indicates a closer match -- confirm against the index implementation.
static inline bool
PositivelyRelated(faiss::MetricType metric_type) {
    switch (metric_type) {
        case faiss::MetricType::METRIC_INNER_PRODUCT:
        case faiss::MetricType::METRIC_Jaccard:
        case faiss::MetricType::METRIC_Tanimoto:
            return true;
        default:
            return false;
    }
}

}  // namespace segcore
}  // namespace milvus

View File

@ -19,6 +19,7 @@
#include "segcore/SegmentGrowing.h"
#include "segcore/SegmentSealed.h"
#include "segcore/segment_c.h"
#include "segcore/SimilarityCorelation.h"
////////////////////////////// common interfaces //////////////////////////////
CSegmentInterface
@ -67,7 +68,8 @@ Search(CSegmentInterface c_segment,
auto plan = (milvus::query::Plan*)c_plan;
auto phg_ptr = reinterpret_cast<const milvus::query::PlaceholderGroup*>(c_placeholder_group);
*search_result = segment->Search(plan, *phg_ptr, timestamp);
if (plan->plan_node_->search_info_.metric_type_ != milvus::MetricType::METRIC_INNER_PRODUCT) {
// if (plan->plan_node_->search_info_.metric_type_ != milvus::MetricType::METRIC_INNER_PRODUCT) {
if (!milvus::segcore::PositivelyRelated(plan->plan_node_->search_info_.metric_type_)) {
for (auto& dis : search_result->result_distances_) {
dis *= -1;
}

View File

@ -36,6 +36,7 @@ set(MILVUS_TEST_FILES
test_timestamp_index.cpp
test_reduce_c.cpp
test_conf_adapter_mgr.cpp
test_similarity_corelation.cpp
)
add_executable(all_tests

View File

@ -0,0 +1,25 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include "segcore/SimilarityCorelation.h"
// Exercises PositivelyRelated with every metric type listed in the two tables
// below: the first table must be classified as positively related, the second
// must not.
TEST(SimilarityCorelation, Naive) {
    using milvus::segcore::PositivelyRelated;

    // Metrics expected to be reported as positively related.
    const faiss::MetricType expect_positive[] = {
        faiss::METRIC_INNER_PRODUCT,
        faiss::METRIC_Jaccard,
        faiss::METRIC_Tanimoto,
    };
    for (const auto metric : expect_positive) {
        ASSERT_TRUE(PositivelyRelated(metric));
    }

    // Metrics expected to be reported as not positively related.
    const faiss::MetricType expect_negative[] = {
        faiss::METRIC_L2,
        faiss::METRIC_Hamming,
        faiss::METRIC_Substructure,
        faiss::METRIC_Superstructure,
    };
    for (const auto metric : expect_negative) {
        ASSERT_FALSE(PositivelyRelated(metric));
    }
}

View File

@ -30,6 +30,8 @@ import (
"strings"
"unsafe"
"github.com/milvus-io/milvus/internal/util/distance"
"go.uber.org/zap"
"github.com/golang/protobuf/proto"
@ -1831,7 +1833,8 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
ret.Results.TopK = realTopK
if metricType != "IP" {
// if metricType != "IP" {
if !distance.PositivelyRelated(metricType) {
for k := range ret.Results.Scores {
ret.Results.Scores[k] *= -1
}

View File

@ -26,6 +26,12 @@ const (
HAMMING = "HAMMING"
// TANIMOTO represents the tanimoto distance
TANIMOTO = "TANIMOTO"
// JACCARD represents the jaccard distance
JACCARD = "JACCARD"
// SUPERSTRUCTURE represents the superstructure metric type
SUPERSTRUCTURE = "SUPERSTRUCTURE"
// SUBSTRUCTURE represents the substructure metric type
SUBSTRUCTURE = "SUBSTRUCTURE"
)
// ValidateMetricType returns metric text or error

View File

@ -0,0 +1,21 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package distance
import "strings"
// PositivelyRelated reports whether metricType, compared after upper-casing,
// names one of the metrics this package classifies as positively related:
// IP, JACCARD or TANIMOTO. Every other string yields false.
func PositivelyRelated(metricType string) bool {
	switch strings.ToUpper(metricType) {
	case strings.ToUpper(IP), strings.ToUpper(JACCARD), strings.ToUpper(TANIMOTO):
		return true
	default:
		return false
	}
}

View File

@ -0,0 +1,56 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package distance
import "testing"
// TestPositivelyRelated checks PositivelyRelated against each metric type
// constant in the table below: IP, JACCARD and TANIMOTO are expected to be
// positively related; L2, HAMMING, SUPERSTRUCTURE and SUBSTRUCTURE are not.
func TestPositivelyRelated(t *testing.T) {
	cases := []struct {
		metricType string
		wanted     bool
	}{
		{metricType: IP, wanted: true},
		{metricType: JACCARD, wanted: true},
		{metricType: TANIMOTO, wanted: true},
		{metricType: L2, wanted: false},
		{metricType: HAMMING, wanted: false},
		{metricType: SUPERSTRUCTURE, wanted: false},
		{metricType: SUBSTRUCTURE, wanted: false},
	}

	for _, tc := range cases {
		if got := PositivelyRelated(tc.metricType); got != tc.wanted {
			// Print the actual result next to the expectation; the previous
			// message showed the expected value in the "got" position.
			t.Errorf("PositivelyRelated(%q) = %v, want %v", tc.metricType, got, tc.wanted)
		}
	}
}