mirror of https://github.com/milvus-io/milvus.git
				
				
				
			test: update restful v2 test cases (#36448)
/kind improvement Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>pull/36367/head^2
							parent
							
								
									ddadefcb62
								
							
						
					
					
						commit
						4779c6cb8f
					
				| 
						 | 
					@ -101,7 +101,8 @@ class TestBase(Base):
 | 
				
			||||||
        batch_size = batch_size
 | 
					        batch_size = batch_size
 | 
				
			||||||
        batch = nb // batch_size
 | 
					        batch = nb // batch_size
 | 
				
			||||||
        remainder = nb % batch_size
 | 
					        remainder = nb % batch_size
 | 
				
			||||||
        data = []
 | 
					
 | 
				
			||||||
 | 
					        full_data = []
 | 
				
			||||||
        insert_ids = []
 | 
					        insert_ids = []
 | 
				
			||||||
        for i in range(batch):
 | 
					        for i in range(batch):
 | 
				
			||||||
            nb = batch_size
 | 
					            nb = batch_size
 | 
				
			||||||
| 
						 | 
					@ -116,6 +117,7 @@ class TestBase(Base):
 | 
				
			||||||
            assert rsp['code'] == 0
 | 
					            assert rsp['code'] == 0
 | 
				
			||||||
            if return_insert_id:
 | 
					            if return_insert_id:
 | 
				
			||||||
                insert_ids.extend(rsp['data']['insertIds'])
 | 
					                insert_ids.extend(rsp['data']['insertIds'])
 | 
				
			||||||
 | 
					            full_data.extend(data)
 | 
				
			||||||
        # insert remainder data
 | 
					        # insert remainder data
 | 
				
			||||||
        if remainder:
 | 
					        if remainder:
 | 
				
			||||||
            nb = remainder
 | 
					            nb = remainder
 | 
				
			||||||
| 
						 | 
					@ -128,10 +130,11 @@ class TestBase(Base):
 | 
				
			||||||
            assert rsp['code'] == 0
 | 
					            assert rsp['code'] == 0
 | 
				
			||||||
            if return_insert_id:
 | 
					            if return_insert_id:
 | 
				
			||||||
                insert_ids.extend(rsp['data']['insertIds'])
 | 
					                insert_ids.extend(rsp['data']['insertIds'])
 | 
				
			||||||
 | 
					            full_data.extend(data)
 | 
				
			||||||
        if return_insert_id:
 | 
					        if return_insert_id:
 | 
				
			||||||
            return schema_payload, data, insert_ids
 | 
					            return schema_payload, full_data, insert_ids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return schema_payload, data
 | 
					        return schema_payload, full_data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def wait_collection_load_completed(self, name):
 | 
					    def wait_collection_load_completed(self, name):
 | 
				
			||||||
        t0 = time.time()
 | 
					        t0 = time.time()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,8 +4,10 @@ import numpy as np
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import utils.utils
 | 
				
			||||||
from utils import constant
 | 
					from utils import constant
 | 
				
			||||||
from utils.utils import gen_collection_name
 | 
					from utils.utils import gen_collection_name, get_sorted_distance
 | 
				
			||||||
from utils.util_log import test_log as logger
 | 
					from utils.util_log import test_log as logger
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
from base.testbase import TestBase
 | 
					from base.testbase import TestBase
 | 
				
			||||||
| 
						 | 
					@ -921,12 +923,10 @@ class TestUpsertVector(TestBase):
 | 
				
			||||||
@pytest.mark.L0
 | 
					@pytest.mark.L0
 | 
				
			||||||
class TestSearchVector(TestBase):
 | 
					class TestSearchVector(TestBase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
    @pytest.mark.parametrize("insert_round", [1])
 | 
					    @pytest.mark.parametrize("insert_round", [1])
 | 
				
			||||||
    @pytest.mark.parametrize("auto_id", [True])
 | 
					    @pytest.mark.parametrize("auto_id", [True])
 | 
				
			||||||
    @pytest.mark.parametrize("is_partition_key", [True])
 | 
					    @pytest.mark.parametrize("is_partition_key", [True])
 | 
				
			||||||
    @pytest.mark.parametrize("enable_dynamic_schema", [True])
 | 
					    @pytest.mark.parametrize("enable_dynamic_schema", [True])
 | 
				
			||||||
    @pytest.mark.skip(reason="behavior change;todo:@zhuwenxing")
 | 
					 | 
				
			||||||
    @pytest.mark.parametrize("nb", [3000])
 | 
					    @pytest.mark.parametrize("nb", [3000])
 | 
				
			||||||
    @pytest.mark.parametrize("dim", [16])
 | 
					    @pytest.mark.parametrize("dim", [16])
 | 
				
			||||||
    def test_search_vector_with_all_vector_datatype(self, nb, dim, insert_round, auto_id,
 | 
					    def test_search_vector_with_all_vector_datatype(self, nb, dim, insert_round, auto_id,
 | 
				
			||||||
| 
						 | 
					@ -1011,14 +1011,7 @@ class TestSearchVector(TestBase):
 | 
				
			||||||
            "filter": "word_count > 100",
 | 
					            "filter": "word_count > 100",
 | 
				
			||||||
            "groupingField": "user_id",
 | 
					            "groupingField": "user_id",
 | 
				
			||||||
            "outputFields": ["*"],
 | 
					            "outputFields": ["*"],
 | 
				
			||||||
            "searchParams": {
 | 
					            "limit": 100
 | 
				
			||||||
                "metricType": "COSINE",
 | 
					 | 
				
			||||||
                "params": {
 | 
					 | 
				
			||||||
                    "radius": "0.1",
 | 
					 | 
				
			||||||
                    "range_filter": "0.8"
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            "limit": 100,
 | 
					 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        rsp = self.vector_client.vector_search(payload)
 | 
					        rsp = self.vector_client.vector_search(payload)
 | 
				
			||||||
        assert rsp['code'] == 0
 | 
					        assert rsp['code'] == 0
 | 
				
			||||||
| 
						 | 
					@ -1032,10 +1025,10 @@ class TestSearchVector(TestBase):
 | 
				
			||||||
    @pytest.mark.parametrize("enable_dynamic_schema", [True])
 | 
					    @pytest.mark.parametrize("enable_dynamic_schema", [True])
 | 
				
			||||||
    @pytest.mark.parametrize("nb", [3000])
 | 
					    @pytest.mark.parametrize("nb", [3000])
 | 
				
			||||||
    @pytest.mark.parametrize("dim", [128])
 | 
					    @pytest.mark.parametrize("dim", [128])
 | 
				
			||||||
    @pytest.mark.skip(reason="behavior change;todo:@zhuwenxing")
 | 
					 | 
				
			||||||
    @pytest.mark.parametrize("nq", [1, 2])
 | 
					    @pytest.mark.parametrize("nq", [1, 2])
 | 
				
			||||||
 | 
					    @pytest.mark.parametrize("metric_type", ['COSINE', "L2", "IP"])
 | 
				
			||||||
    def test_search_vector_with_float_vector_datatype(self, nb, dim, insert_round, auto_id,
 | 
					    def test_search_vector_with_float_vector_datatype(self, nb, dim, insert_round, auto_id,
 | 
				
			||||||
                                                      is_partition_key, enable_dynamic_schema, nq):
 | 
					                                                      is_partition_key, enable_dynamic_schema, nq, metric_type):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Insert a vector with a simple payload
 | 
					        Insert a vector with a simple payload
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -1056,7 +1049,7 @@ class TestSearchVector(TestBase):
 | 
				
			||||||
                ]
 | 
					                ]
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
            "indexParams": [
 | 
					            "indexParams": [
 | 
				
			||||||
                {"fieldName": "float_vector", "indexName": "float_vector", "metricType": "COSINE"},
 | 
					                {"fieldName": "float_vector", "indexName": "float_vector", "metricType": metric_type},
 | 
				
			||||||
            ]
 | 
					            ]
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        rsp = self.collection_client.collection_create(payload)
 | 
					        rsp = self.collection_client.collection_create(payload)
 | 
				
			||||||
| 
						 | 
					@ -1100,13 +1093,6 @@ class TestSearchVector(TestBase):
 | 
				
			||||||
            "filter": "word_count > 100",
 | 
					            "filter": "word_count > 100",
 | 
				
			||||||
            "groupingField": "user_id",
 | 
					            "groupingField": "user_id",
 | 
				
			||||||
            "outputFields": ["*"],
 | 
					            "outputFields": ["*"],
 | 
				
			||||||
            "searchParams": {
 | 
					 | 
				
			||||||
                "metricType": "COSINE",
 | 
					 | 
				
			||||||
                "params": {
 | 
					 | 
				
			||||||
                    "radius": "0.1",
 | 
					 | 
				
			||||||
                    "range_filter": "0.8"
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            "limit": 100,
 | 
					            "limit": 100,
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        rsp = self.vector_client.vector_search(payload)
 | 
					        rsp = self.vector_client.vector_search(payload)
 | 
				
			||||||
| 
						 | 
					@ -1227,8 +1213,8 @@ class TestSearchVector(TestBase):
 | 
				
			||||||
    @pytest.mark.parametrize("enable_dynamic_schema", [True])
 | 
					    @pytest.mark.parametrize("enable_dynamic_schema", [True])
 | 
				
			||||||
    @pytest.mark.parametrize("nb", [3000])
 | 
					    @pytest.mark.parametrize("nb", [3000])
 | 
				
			||||||
    @pytest.mark.parametrize("dim", [128])
 | 
					    @pytest.mark.parametrize("dim", [128])
 | 
				
			||||||
    @pytest.mark.skip(reason="behavior change;todo:@zhuwenxing")
 | 
					    @pytest.mark.parametrize("metric_type", ['HAMMING'])
 | 
				
			||||||
    def test_search_vector_with_binary_vector_datatype(self, nb, dim, insert_round, auto_id,
 | 
					    def test_search_vector_with_binary_vector_datatype(self, metric_type, nb, dim, insert_round, auto_id,
 | 
				
			||||||
                                                      is_partition_key, enable_dynamic_schema):
 | 
					                                                      is_partition_key, enable_dynamic_schema):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Insert a vector with a simple payload
 | 
					        Insert a vector with a simple payload
 | 
				
			||||||
| 
						 | 
					@ -1250,7 +1236,7 @@ class TestSearchVector(TestBase):
 | 
				
			||||||
                ]
 | 
					                ]
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
            "indexParams": [
 | 
					            "indexParams": [
 | 
				
			||||||
                {"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": "HAMMING",
 | 
					                {"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": metric_type,
 | 
				
			||||||
                 "params": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
 | 
					                 "params": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
 | 
				
			||||||
            ]
 | 
					            ]
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
| 
						 | 
					@ -1301,13 +1287,6 @@ class TestSearchVector(TestBase):
 | 
				
			||||||
            "data": [gen_vector(datatype="BinaryVector", dim=dim)],
 | 
					            "data": [gen_vector(datatype="BinaryVector", dim=dim)],
 | 
				
			||||||
            "filter": "word_count > 100",
 | 
					            "filter": "word_count > 100",
 | 
				
			||||||
            "outputFields": ["*"],
 | 
					            "outputFields": ["*"],
 | 
				
			||||||
            "searchParams": {
 | 
					 | 
				
			||||||
                "metricType": "HAMMING",
 | 
					 | 
				
			||||||
                "params": {
 | 
					 | 
				
			||||||
                    "radius": "0.1",
 | 
					 | 
				
			||||||
                    "range_filter": "0.8"
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            "limit": 100,
 | 
					            "limit": 100,
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        rsp = self.vector_client.vector_search(payload)
 | 
					        rsp = self.vector_client.vector_search(payload)
 | 
				
			||||||
| 
						 | 
					@ -1549,6 +1528,130 @@ class TestSearchVector(TestBase):
 | 
				
			||||||
            if "like" in varchar_expr:
 | 
					            if "like" in varchar_expr:
 | 
				
			||||||
                assert name.startswith(prefix)
 | 
					                assert name.startswith(prefix)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @pytest.mark.parametrize("consistency_level", ["Strong", "Bounded", "Eventually", "Session"])
 | 
				
			||||||
 | 
					    def test_search_vector_with_consistency_level(self, consistency_level):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Search a vector with different consistency level
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        name = gen_collection_name()
 | 
				
			||||||
 | 
					        self.name = name
 | 
				
			||||||
 | 
					        nb = 200
 | 
				
			||||||
 | 
					        dim = 128
 | 
				
			||||||
 | 
					        limit = 100
 | 
				
			||||||
 | 
					        schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
 | 
				
			||||||
 | 
					        names = []
 | 
				
			||||||
 | 
					        for item in data:
 | 
				
			||||||
 | 
					            names.append(item.get("name"))
 | 
				
			||||||
 | 
					        names.sort()
 | 
				
			||||||
 | 
					        logger.info(f"names: {names}")
 | 
				
			||||||
 | 
					        mid = len(names) // 2
 | 
				
			||||||
 | 
					        prefix = names[mid][0:2]
 | 
				
			||||||
 | 
					        vector_field = schema_payload.get("vectorField")
 | 
				
			||||||
 | 
					        # search data
 | 
				
			||||||
 | 
					        vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
 | 
				
			||||||
 | 
					        output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
 | 
				
			||||||
 | 
					        payload = {
 | 
				
			||||||
 | 
					            "collectionName": name,
 | 
				
			||||||
 | 
					            "data": [vector_to_search],
 | 
				
			||||||
 | 
					            "outputFields": output_fields,
 | 
				
			||||||
 | 
					            "limit": limit,
 | 
				
			||||||
 | 
					            "offset": 0,
 | 
				
			||||||
 | 
					            "consistencyLevel": consistency_level
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        rsp = self.vector_client.vector_search(payload)
 | 
				
			||||||
 | 
					        assert rsp['code'] == 0
 | 
				
			||||||
 | 
					        res = rsp['data']
 | 
				
			||||||
 | 
					        logger.info(f"res: {len(res)}")
 | 
				
			||||||
 | 
					        assert len(res) == limit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @pytest.mark.parametrize("metric_type", ["L2", "COSINE", "IP"])
 | 
				
			||||||
 | 
					    def test_search_vector_with_range_search(self, metric_type):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Search a vector with range search with different metric type
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        name = gen_collection_name()
 | 
				
			||||||
 | 
					        self.name = name
 | 
				
			||||||
 | 
					        nb = 3000
 | 
				
			||||||
 | 
					        dim = 128
 | 
				
			||||||
 | 
					        limit = 100
 | 
				
			||||||
 | 
					        schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type)
 | 
				
			||||||
 | 
					        vector_field = schema_payload.get("vectorField")
 | 
				
			||||||
 | 
					        # search data
 | 
				
			||||||
 | 
					        vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
 | 
				
			||||||
 | 
					        training_data = [item[vector_field] for item in data]
 | 
				
			||||||
 | 
					        distance_sorted = get_sorted_distance(training_data, [vector_to_search], metric_type)
 | 
				
			||||||
 | 
					        r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.2*limit))] # recall is not 100% so add 20% to make sure the range is correct
 | 
				
			||||||
 | 
					        if metric_type == "L2":
 | 
				
			||||||
 | 
					            r1, r2 = r2, r1
 | 
				
			||||||
 | 
					        output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
 | 
				
			||||||
 | 
					        payload = {
 | 
				
			||||||
 | 
					            "collectionName": name,
 | 
				
			||||||
 | 
					            "data": [vector_to_search],
 | 
				
			||||||
 | 
					            "outputFields": output_fields,
 | 
				
			||||||
 | 
					            "limit": limit,
 | 
				
			||||||
 | 
					            "offset": 0,
 | 
				
			||||||
 | 
					            "searchParams": {
 | 
				
			||||||
 | 
					                "params": {
 | 
				
			||||||
 | 
					                    "radius": r1,
 | 
				
			||||||
 | 
					                    "range_filter": r2,
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        rsp = self.vector_client.vector_search(payload)
 | 
				
			||||||
 | 
					        assert rsp['code'] == 0
 | 
				
			||||||
 | 
					        res = rsp['data']
 | 
				
			||||||
 | 
					        logger.info(f"res: {len(res)}")
 | 
				
			||||||
 | 
					        assert len(res) == limit
 | 
				
			||||||
 | 
					        for item in res:
 | 
				
			||||||
 | 
					            distance = item.get("distance")
 | 
				
			||||||
 | 
					            if metric_type == "L2":
 | 
				
			||||||
 | 
					                assert r1 > distance > r2
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                assert r1 < distance < r2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @pytest.mark.parametrize("ignore_growing", [True, False])
 | 
				
			||||||
 | 
					    def test_search_vector_with_ignore_growing(self, ignore_growing):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Search a vector with range search with different metric type
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        name = gen_collection_name()
 | 
				
			||||||
 | 
					        self.name = name
 | 
				
			||||||
 | 
					        metric_type = "COSINE"
 | 
				
			||||||
 | 
					        nb = 1000
 | 
				
			||||||
 | 
					        dim = 128
 | 
				
			||||||
 | 
					        limit = 100
 | 
				
			||||||
 | 
					        schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type)
 | 
				
			||||||
 | 
					        vector_field = schema_payload.get("vectorField")
 | 
				
			||||||
 | 
					        # search data
 | 
				
			||||||
 | 
					        vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
 | 
				
			||||||
 | 
					        training_data = [item[vector_field] for item in data]
 | 
				
			||||||
 | 
					        distance_sorted = get_sorted_distance(training_data, [vector_to_search], metric_type)
 | 
				
			||||||
 | 
					        r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.2*limit))] # recall is not 100% so add 20% to make sure the range is correct
 | 
				
			||||||
 | 
					        if metric_type == "L2":
 | 
				
			||||||
 | 
					            r1, r2 = r2, r1
 | 
				
			||||||
 | 
					        output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        payload = {
 | 
				
			||||||
 | 
					            "collectionName": name,
 | 
				
			||||||
 | 
					            "data": [vector_to_search],
 | 
				
			||||||
 | 
					            "outputFields": output_fields,
 | 
				
			||||||
 | 
					            "limit": limit,
 | 
				
			||||||
 | 
					            "offset": 0,
 | 
				
			||||||
 | 
					            "searchParams": {
 | 
				
			||||||
 | 
					                "ignore_growing": ignore_growing
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        rsp = self.vector_client.vector_search(payload)
 | 
				
			||||||
 | 
					        assert rsp['code'] == 0
 | 
				
			||||||
 | 
					        res = rsp['data']
 | 
				
			||||||
 | 
					        logger.info(f"res: {len(res)}")
 | 
				
			||||||
 | 
					        if ignore_growing is True:
 | 
				
			||||||
 | 
					            assert len(res) == 0
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            assert len(res) == limit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.L1
 | 
					@pytest.mark.L1
 | 
				
			||||||
class TestSearchVectorNegative(TestBase):
 | 
					class TestSearchVectorNegative(TestBase):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,7 +10,7 @@ import base64
 | 
				
			||||||
import requests
 | 
					import requests
 | 
				
			||||||
from loguru import logger
 | 
					from loguru import logger
 | 
				
			||||||
import datetime
 | 
					import datetime
 | 
				
			||||||
 | 
					from sklearn.metrics import pairwise_distances
 | 
				
			||||||
fake = Faker()
 | 
					fake = Faker()
 | 
				
			||||||
rng = np.random.default_rng()
 | 
					rng = np.random.default_rng()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -240,4 +240,28 @@ def get_all_fields_by_data(data, exclude_fields=None):
 | 
				
			||||||
    return list(fields)
 | 
					    return list(fields)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def ip_distance(x, y):
 | 
				
			||||||
 | 
					    return np.dot(x, y)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cosine_distance(u, v, epsilon=1e-8):
 | 
				
			||||||
 | 
					    dot_product = np.dot(u, v)
 | 
				
			||||||
 | 
					    norm_u = np.linalg.norm(u)
 | 
				
			||||||
 | 
					    norm_v = np.linalg.norm(v)
 | 
				
			||||||
 | 
					    return dot_product / (max(norm_u * norm_v, epsilon))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def l2_distance(u, v):
 | 
				
			||||||
 | 
					    return np.sum((u - v) ** 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_sorted_distance(train_emb, test_emb, metric_type):
 | 
				
			||||||
 | 
					    milvus_sklearn_metric_map = {
 | 
				
			||||||
 | 
					        "L2": l2_distance,
 | 
				
			||||||
 | 
					        "COSINE": cosine_distance,
 | 
				
			||||||
 | 
					        "IP": ip_distance
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    distance = pairwise_distances(train_emb, Y=test_emb, metric=milvus_sklearn_metric_map[metric_type], n_jobs=-1)
 | 
				
			||||||
 | 
					    distance = np.array(distance.T, order='C', dtype=np.float16)
 | 
				
			||||||
 | 
					    distance_sorted = np.sort(distance, axis=1).tolist()
 | 
				
			||||||
 | 
					    return distance_sorted
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue