mirror of https://github.com/milvus-io/milvus.git
add unittest for build/search index
parent
d37f89dc3d
commit
855d1c613d
|
@ -7,7 +7,9 @@ import unittest
|
||||||
|
|
||||||
class TestBuildIndex(unittest.TestCase):
|
class TestBuildIndex(unittest.TestCase):
|
||||||
def test_factory_method(self):
|
def test_factory_method(self):
|
||||||
pass
|
index_builder = FactoryIndex()
|
||||||
|
index = index_builder()
|
||||||
|
self.assertIsInstance(index, DefaultIndex)
|
||||||
|
|
||||||
def test_default_index(self):
|
def test_default_index(self):
|
||||||
d = 64
|
d = 64
|
||||||
|
@ -30,15 +32,38 @@ class TestBuildIndex(unittest.TestCase):
|
||||||
d = 64
|
d = 64
|
||||||
nb = 10000
|
nb = 10000
|
||||||
nq = 100
|
nq = 100
|
||||||
_, xb, xq = get_dataset(d, nb, 500, nq)
|
nt = 500
|
||||||
|
xt, xb, xq = get_dataset(d, nb, nt, nq)
|
||||||
|
|
||||||
index = faiss.IndexFlatL2(d)
|
index = faiss.IndexFlatL2(d)
|
||||||
index.add(xb)
|
index.add(xb)
|
||||||
|
|
||||||
pass
|
assert index.ntotal == nb
|
||||||
|
|
||||||
|
Index.increase(index, xt)
|
||||||
|
assert index.ntotal == nb + nt
|
||||||
|
|
||||||
def test_serialize(self):
|
def test_serialize(self):
|
||||||
pass
|
d = 64
|
||||||
|
nb = 10000
|
||||||
|
nq = 100
|
||||||
|
nt = 500
|
||||||
|
xt, xb, xq = get_dataset(d, nb, nt, nq)
|
||||||
|
|
||||||
|
index = faiss.IndexFlatL2(d)
|
||||||
|
index.add(xb)
|
||||||
|
Dref, Iref = index.search(xq, 5)
|
||||||
|
|
||||||
|
ar_data = Index.serialize(index)
|
||||||
|
|
||||||
|
reader = faiss.VectorIOReader()
|
||||||
|
faiss.copy_array_to_vector(ar_data, reader.data)
|
||||||
|
index2 = faiss.read_index(reader)
|
||||||
|
|
||||||
|
Dnew, Inew = index2.search(xq, 5)
|
||||||
|
|
||||||
|
assert np.all(Dnew == Dref) and np.all(Inew == Iref)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(d, nb, nt, nq):
|
def get_dataset(d, nb, nt, nq):
|
||||||
|
|
|
@ -1,3 +0,0 @@
|
||||||
from engine.controller import scheduler
|
|
||||||
|
|
||||||
scheduler.Scheduler.Search()
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
from ..search_index import *
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class TestSearchSingleThread(unittest.TestCase):
|
||||||
|
def test_search_by_vectors(self):
|
||||||
|
d = 64
|
||||||
|
nb = 10000
|
||||||
|
nq = 100
|
||||||
|
_, xb, xq = get_dataset(d, nb, 500, nq)
|
||||||
|
|
||||||
|
index = faiss.IndexFlatL2(d)
|
||||||
|
index.add(xb)
|
||||||
|
|
||||||
|
# expect result
|
||||||
|
Dref, Iref = index.search(xq, 5)
|
||||||
|
|
||||||
|
searcher = FaissSearch(index)
|
||||||
|
result = searcher.search_by_vectors(xq, 5)
|
||||||
|
|
||||||
|
assert np.all(result.distance == Dref) \
|
||||||
|
and np.all(result.vectors == Iref)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_top_k(selfs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(d, nb, nt, nq):
|
||||||
|
"""A dataset that is not completely random but still challenging to
|
||||||
|
index
|
||||||
|
"""
|
||||||
|
d1 = 10 # intrinsic dimension (more or less)
|
||||||
|
n = nb + nt + nq
|
||||||
|
rs = np.random.RandomState(1338)
|
||||||
|
x = rs.normal(size=(n, d1))
|
||||||
|
x = np.dot(x, rs.rand(d1, d))
|
||||||
|
# now we have a d1-dim ellipsoid in d-dimensional space
|
||||||
|
# higher factor (>4) -> higher frequency -> less linear
|
||||||
|
x = x * (rs.rand(d) * 4 + 0.1)
|
||||||
|
x = np.sin(x)
|
||||||
|
x = x.astype('float32')
|
||||||
|
return x[:nt], x[nt:-nq], x[-nq:]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue