mirror of https://github.com/milvus-io/milvus.git
commit
b1389854f0
|
@ -1,5 +1,7 @@
|
|||
from engine.retrieval import search_index
|
||||
from engine.ingestion import build_index
|
||||
from engine.ingestion import serialize
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
_instances = {}
|
||||
|
@ -11,40 +13,40 @@ class Singleton(type):
|
|||
|
||||
class Scheduler(metaclass=Singleton):
|
||||
def Search(self, index_file_key, vectors, k):
|
||||
assert index_file_key
|
||||
assert vectors
|
||||
assert k
|
||||
# assert index_file_key
|
||||
# assert vectors
|
||||
assert k != 0
|
||||
|
||||
return self.__scheduler(index_file_key, vectors, k)
|
||||
query_vectors = serialize.to_array(vectors)
|
||||
return self.__scheduler(index_file_key, query_vectors, k)
|
||||
|
||||
|
||||
def __scheduler(self, index_data_key, vectors, k):
|
||||
result_list = []
|
||||
|
||||
raw_data_list = index_data_key['raw']
|
||||
index_data_list = index_data_key['index']
|
||||
|
||||
for key in raw_data_list:
|
||||
raw_data, d = self.GetRawData(key)
|
||||
if 'raw' in index_data_key:
|
||||
raw_vectors = index_data_key['raw']
|
||||
d = index_data_key['dimension']
|
||||
index_builder = build_index.FactoryIndex()
|
||||
index = index_builder().build(d, raw_data)
|
||||
searcher = search_index.FaissSearch(index) # silly
|
||||
result_list.append(searcher.search_by_vectors(vectors, k))
|
||||
|
||||
for key in index_data_list:
|
||||
index = self.GetIndexData(key)
|
||||
index = index_builder().build(d, raw_vectors)
|
||||
searcher = search_index.FaissSearch(index)
|
||||
result_list.append(searcher.search_by_vectors(vectors, k))
|
||||
|
||||
if 'index' in index_data_key:
|
||||
index_data_list = index_data_key['index']
|
||||
for key in index_data_list:
|
||||
index = GetIndexData(key)
|
||||
searcher = search_index.FaissSearch(index)
|
||||
result_list.append(searcher.search_by_vectors(vectors, k))
|
||||
|
||||
if len(result_list) == 1:
|
||||
return result_list[0].vectors
|
||||
|
||||
result = search_index.top_k(sum(result_list), k)
|
||||
return result
|
||||
total_result = []
|
||||
|
||||
# result = search_index.top_k(result_list, k)
|
||||
return result_list
|
||||
|
||||
|
||||
def GetIndexData(self, key):
|
||||
pass
|
||||
|
||||
def GetRawData(self, key):
|
||||
pass
|
||||
def GetIndexData(key):
|
||||
return serialize.read_index(key)
|
|
@ -0,0 +1,60 @@
|
|||
from ..scheduler import *
|
||||
|
||||
import unittest
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestScheduler(unittest.TestCase):
|
||||
def test_schedule(self):
|
||||
d = 64
|
||||
nb = 10000
|
||||
nq = 100
|
||||
nt = 5000
|
||||
xt, xb, xq = get_dataset(d, nb, nt, nq)
|
||||
file_name = "/tmp/faiss/tempfile_1"
|
||||
|
||||
|
||||
index = faiss.IndexFlatL2(d)
|
||||
print(index.is_trained)
|
||||
index.add(xb)
|
||||
faiss.write_index(index, file_name)
|
||||
Dref, Iref = index.search(xq, 5)
|
||||
|
||||
index2 = faiss.read_index(file_name)
|
||||
|
||||
schuduler_instance = Scheduler()
|
||||
|
||||
# query args 1
|
||||
query_index = dict()
|
||||
query_index['index'] = [file_name]
|
||||
vectors = schuduler_instance.Search(query_index, vectors=xq, k=5)
|
||||
assert np.all(vectors == Iref)
|
||||
|
||||
# query args 2
|
||||
query_index = dict()
|
||||
query_index['raw'] = xt
|
||||
query_index['dimension'] = d
|
||||
query_index['index'] = [file_name]
|
||||
vectors = schuduler_instance.Search(query_index, vectors=xq, k=5)
|
||||
# print("success")
|
||||
|
||||
|
||||
def get_dataset(d, nb, nt, nq):
|
||||
"""A dataset that is not completely random but still challenging to
|
||||
index
|
||||
"""
|
||||
d1 = 10 # intrinsic dimension (more or less)
|
||||
n = nb + nt + nq
|
||||
rs = np.random.RandomState(1338)
|
||||
x = rs.normal(size=(n, d1))
|
||||
x = np.dot(x, rs.rand(d1, d))
|
||||
# now we have a d1-dim ellipsoid in d-dimensional space
|
||||
# higher factor (>4) -> higher frequency -> less linear
|
||||
x = x * (rs.rand(d) * 4 + 0.1)
|
||||
x = np.sin(x)
|
||||
x = x.astype('float32')
|
||||
return x[:nt], x[nt:-nq], x[-nq:]
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -3,6 +3,7 @@ from engine.settings import DATABASE_DIRECTORY
|
|||
from flask import jsonify
|
||||
import pytest
|
||||
import os
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
@ -11,7 +12,7 @@ logger = logging.getLogger(__name__)
|
|||
class TestVectorEngine:
|
||||
def setup_class(self):
|
||||
self.__vector = [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]
|
||||
self.__limit = 3
|
||||
self.__limit = 1
|
||||
|
||||
|
||||
def teardown_class(self):
|
||||
|
@ -39,6 +40,7 @@ class TestVectorEngine:
|
|||
# Check the group list
|
||||
code, group_list = VectorEngine.GetGroupList()
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
print("group_list: ", group_list)
|
||||
assert group_list == [{'group_name': 'test_group', 'file_number': 0}]
|
||||
|
||||
# Add Vector for not exist group
|
||||
|
@ -49,6 +51,18 @@ class TestVectorEngine:
|
|||
code = VectorEngine.AddVector('test_group', self.__vector)
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
|
||||
# Add vector for exist group
|
||||
code = VectorEngine.AddVector('test_group', self.__vector)
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
|
||||
# Add vector for exist group
|
||||
code = VectorEngine.AddVector('test_group', self.__vector)
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
|
||||
# Add vector for exist group
|
||||
code = VectorEngine.AddVector('test_group', self.__vector)
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
|
||||
# Check search vector interface
|
||||
code, vector_id = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
|
@ -89,10 +103,12 @@ class TestVectorEngine:
|
|||
expected_list = [self.__vector]
|
||||
vector_list = VectorEngine.GetVectorListFromRawFile('test_group', filename)
|
||||
|
||||
|
||||
print('expected_list: ', expected_list)
|
||||
print('vector_list: ', vector_list)
|
||||
expected_list = np.asarray(expected_list).astype('float32')
|
||||
|
||||
assert vector_list == expected_list
|
||||
assert np.all(vector_list == expected_list)
|
||||
|
||||
code = VectorEngine.ClearRawFile('test_group')
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
|
|
|
@ -11,6 +11,7 @@ logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(le
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TestViews:
|
||||
HEADERS = {'Content-Type': 'application/json'}
|
||||
|
||||
def loads(self, resp):
|
||||
return json.loads(resp.data.decode())
|
||||
|
@ -18,51 +19,61 @@ class TestViews:
|
|||
def test_group(self, test_client):
|
||||
data = {"dimension": 10}
|
||||
|
||||
resp = test_client.get('/vector/group/6')
|
||||
resp = test_client.get('/vector/group/6', headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 1
|
||||
|
||||
resp = test_client.post('/vector/group/6', data=json.dumps(data))
|
||||
resp = test_client.post('/vector/group/6', data=json.dumps(data), headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
|
||||
resp = test_client.get('/vector/group/6')
|
||||
resp = test_client.get('/vector/group/6', headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
|
||||
# GroupList
|
||||
resp = test_client.get('/vector/group')
|
||||
resp = test_client.get('/vector/group', headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
assert self.loads(resp)['group_list'] == [{'file_number': 0, 'group_name': '6'}]
|
||||
|
||||
resp = test_client.delete('/vector/group/6')
|
||||
resp = test_client.delete('/vector/group/6', headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
|
||||
|
||||
def test_vector(self, test_client):
|
||||
dimension = {"dimension": 10}
|
||||
resp = test_client.post('/vector/group/6', data=json.dumps(dimension))
|
||||
dimension = {"dimension": 8}
|
||||
resp = test_client.post('/vector/group/6', data=json.dumps(dimension), headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
|
||||
vector = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]}
|
||||
resp = test_client.post('/vector/add/6', data=json.dumps(vector))
|
||||
resp = test_client.post('/vector/add/6', data=json.dumps(vector), headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
|
||||
resp = test_client.post('/vector/index/6')
|
||||
vector = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]}
|
||||
resp = test_client.post('/vector/add/6', data=json.dumps(vector), headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
|
||||
limit = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], "limit": 3}
|
||||
resp = test_client.get('/vector/search/6', data=json.dumps(limit))
|
||||
vector = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]}
|
||||
resp = test_client.post('/vector/add/6', data=json.dumps(vector), headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
|
||||
resp = test_client.post('/vector/index/6', headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
|
||||
limit = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], "limit": 1}
|
||||
resp = test_client.get('/vector/search/6', data=json.dumps(limit), headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
assert self.loads(resp)['vector_id'] == 0
|
||||
|
||||
resp = test_client.delete('/vector/group/6')
|
||||
resp = test_client.delete('/vector/group/6', headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from flask import jsonify
|
|||
from engine import db
|
||||
from engine.ingestion import build_index
|
||||
from engine.controller.scheduler import Scheduler
|
||||
from engine.ingestion import serialize
|
||||
import sys, os
|
||||
|
||||
class VectorEngine(object):
|
||||
|
@ -93,7 +94,7 @@ class VectorEngine(object):
|
|||
|
||||
# check if the file can be indexed
|
||||
if file.row_number + 1 >= ROW_LIMIT:
|
||||
raw_data = GetVectorListFromRawFile(group_id)
|
||||
raw_data = VectorEngine.GetVectorListFromRawFile(group_id)
|
||||
d = group.dimension
|
||||
|
||||
# create index
|
||||
|
@ -102,9 +103,11 @@ class VectorEngine(object):
|
|||
|
||||
# TODO(jinhai): store index into Cache
|
||||
index_filename = file.filename + '_index'
|
||||
serialize.write_index(file_name=index_filename, index=index)
|
||||
|
||||
# TODO(jinhai): Update raw_file_name => index_file_name
|
||||
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1, 'type': 'index'})
|
||||
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1,
|
||||
'type': 'index',
|
||||
'filename': index_filename})
|
||||
pass
|
||||
|
||||
else:
|
||||
|
@ -134,16 +137,20 @@ class VectorEngine(object):
|
|||
if code == VectorEngine.FAULT_CODE:
|
||||
return VectorEngine.GROUP_NOT_EXIST
|
||||
|
||||
group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
|
||||
|
||||
# find all files
|
||||
files = FileTable.query.filter(FileTable.group_name == group_id).all()
|
||||
raw_keys = [ i.filename for i in files if i.type == 'raw' ]
|
||||
index_keys = [ i.filename for i in files if i.type == 'index' ]
|
||||
index_map = {}
|
||||
index_map['raw'] = raw_keys
|
||||
index_map['index'] = index_keys # {raw:[key1, key2], index:[key3, key4]}
|
||||
index_map['index'] = index_keys
|
||||
index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") #TODO: pass by key, get from storage
|
||||
index_map['dimension'] = group.dimension
|
||||
|
||||
scheduler_instance = Scheduler()
|
||||
result = scheduler_instance.Search(index_map, vector, limit)
|
||||
vectors = []
|
||||
vectors.append(vector)
|
||||
result = scheduler_instance.Search(index_map, vectors, limit)
|
||||
|
||||
vector_id = 0
|
||||
|
||||
|
@ -183,7 +190,7 @@ class VectorEngine(object):
|
|||
|
||||
@staticmethod
|
||||
def GetVectorListFromRawFile(group_id, filename="todo"):
|
||||
return VectorEngine.group_dict[group_id]
|
||||
return serialize.to_array(VectorEngine.group_dict[group_id])
|
||||
|
||||
@staticmethod
|
||||
def ClearRawFile(group_id):
|
||||
|
|
|
@ -26,11 +26,12 @@ class VectorSearch(Resource):
|
|||
def __init__(self):
|
||||
self.__parser = reqparse.RequestParser()
|
||||
self.__parser.add_argument('vector', type=float, action='append', location=['json'])
|
||||
self.__parser.add_argument('limit', type=int, action='append', location=['json'])
|
||||
self.__parser.add_argument('limit', type=int, location=['json'])
|
||||
|
||||
def get(self, group_id):
|
||||
args = self.__parser.parse_args()
|
||||
print('vector: ', args['vector'])
|
||||
print('VectorSearch vector: ', args['vector'])
|
||||
print('limit: ', args['limit'])
|
||||
# go to search every thing
|
||||
code, vector_id = VectorEngine.SearchVector(group_id, args['vector'], args['limit'])
|
||||
return jsonify({'code': code, 'vector_id': vector_id})
|
||||
|
@ -50,7 +51,7 @@ class Group(Resource):
|
|||
def __init__(self):
|
||||
self.__parser = reqparse.RequestParser()
|
||||
self.__parser.add_argument('group_id', type=str)
|
||||
self.__parser.add_argument('dimension', type=int, action='append', location=['json'])
|
||||
self.__parser.add_argument('dimension', type=int, location=['json'])
|
||||
|
||||
def post(self, group_id):
|
||||
args = self.__parser.parse_args()
|
||||
|
|
|
@ -15,7 +15,7 @@ def FactoryIndex(index_name="DefaultIndex"):
|
|||
|
||||
|
||||
class Index():
|
||||
def build(d, vectors, DEVICE=INDEX_DEVICES.CPU):
|
||||
def build(self, d, vectors, DEVICE=INDEX_DEVICES.CPU):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
|
@ -35,7 +35,7 @@ class DefaultIndex(Index):
|
|||
# maybe need to specif parameters
|
||||
pass
|
||||
|
||||
def build(d, vectors, DEVICE=INDEX_DEVICES.CPU):
|
||||
def build(self, d, vectors, DEVICE=INDEX_DEVICES.CPU):
|
||||
index = faiss.IndexFlatL2(d) # trained
|
||||
index.add(vectors)
|
||||
return index
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
import faiss
|
||||
import numpy as np
|
||||
|
||||
def write_index(index, file_name):
|
||||
faiss.write_index(index, file_name)
|
||||
|
||||
def read_index(file_name):
|
||||
return faiss.read_index(file_name)
|
||||
|
||||
def to_array(vec):
|
||||
return np.asarray(vec).astype('float32')
|
|
@ -0,0 +1,88 @@
|
|||
from ..build_index import *
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
||||
class TestBuildIndex(unittest.TestCase):
|
||||
def test_factory_method(self):
|
||||
index_builder = FactoryIndex()
|
||||
index = index_builder()
|
||||
self.assertIsInstance(index, DefaultIndex)
|
||||
|
||||
def test_default_index(self):
|
||||
d = 64
|
||||
nb = 10000
|
||||
nq = 100
|
||||
_, xb, xq = get_dataset(d, nb, 500, nq)
|
||||
|
||||
# Expected result
|
||||
index = faiss.IndexFlatL2(d)
|
||||
index.add(xb)
|
||||
Dref, Iref = index.search(xq, 5)
|
||||
|
||||
builder = DefaultIndex()
|
||||
index2 = builder.build(d, xb)
|
||||
Dnew, Inew = index2.search(xq, 5)
|
||||
|
||||
assert np.all(Dnew == Dref) and np.all(Inew == Iref)
|
||||
|
||||
def test_increase(self):
|
||||
d = 64
|
||||
nb = 10000
|
||||
nq = 100
|
||||
nt = 500
|
||||
xt, xb, xq = get_dataset(d, nb, nt, nq)
|
||||
|
||||
index = faiss.IndexFlatL2(d)
|
||||
index.add(xb)
|
||||
|
||||
assert index.ntotal == nb
|
||||
|
||||
Index.increase(index, xt)
|
||||
assert index.ntotal == nb + nt
|
||||
|
||||
def test_serialize(self):
|
||||
d = 64
|
||||
nb = 10000
|
||||
nq = 100
|
||||
nt = 500
|
||||
xt, xb, xq = get_dataset(d, nb, nt, nq)
|
||||
|
||||
index = faiss.IndexFlatL2(d)
|
||||
index.add(xb)
|
||||
Dref, Iref = index.search(xq, 5)
|
||||
|
||||
ar_data = Index.serialize(index)
|
||||
|
||||
reader = faiss.VectorIOReader()
|
||||
faiss.copy_array_to_vector(ar_data, reader.data)
|
||||
index2 = faiss.read_index(reader)
|
||||
|
||||
Dnew, Inew = index2.search(xq, 5)
|
||||
|
||||
assert np.all(Dnew == Dref) and np.all(Inew == Iref)
|
||||
|
||||
|
||||
|
||||
def get_dataset(d, nb, nt, nq):
|
||||
"""A dataset that is not completely random but still challenging to
|
||||
index
|
||||
"""
|
||||
d1 = 10 # intrinsic dimension (more or less)
|
||||
n = nb + nt + nq
|
||||
rs = np.random.RandomState(1338)
|
||||
x = rs.normal(size=(n, d1))
|
||||
x = np.dot(x, rs.rand(d1, d))
|
||||
# now we have a d1-dim ellipsoid in d-dimensional space
|
||||
# higher factor (>4) -> higher frequency -> less linear
|
||||
x = x * (rs.rand(d) * 4 + 0.1)
|
||||
x = np.sin(x)
|
||||
x = x.astype('float32')
|
||||
return x[:nt], x[nt:-nq], x[-nq:]
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -12,7 +12,6 @@ class GroupTable(db.Model):
|
|||
self.group_name = group_name
|
||||
self.dimension = dimension
|
||||
self.file_number = 0
|
||||
self.dimension = 0
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
@ -7,8 +7,9 @@ class SearchResult():
|
|||
self.vectors = I
|
||||
|
||||
def __add__(self, other):
|
||||
self.distance += other.distance
|
||||
self.vectors += other.vectors
|
||||
distance = self.distance + other.distance
|
||||
vectors = self.vectors + other.vectors
|
||||
return SearchResult(distance, vectors)
|
||||
|
||||
|
||||
class FaissSearch():
|
||||
|
@ -31,6 +32,7 @@ class FaissSearch():
|
|||
D, I = self.__index.search(vector_list, k)
|
||||
return SearchResult(D, I)
|
||||
|
||||
|
||||
import heapq
|
||||
def top_k(input, k):
|
||||
#sorted = heapq.nsmallest(k, input, key=input.key)
|
||||
pass
|
|
@ -74,7 +74,6 @@ def basic_test():
|
|||
index.add(xb) # add vectors to the index
|
||||
print(index.ntotal)
|
||||
#faiss.write_index(index, "/tmp/faiss/tempfile_1")
|
||||
|
||||
writer = faiss.VectorIOWriter()
|
||||
faiss.write_index(index, writer)
|
||||
ar_data = faiss.vector_to_array(writer.data)
|
||||
|
@ -101,4 +100,4 @@ def basic_test():
|
|||
# print(I[-5:]) # neighbors of the 5 last queries
|
||||
|
||||
if __name__ == '__main__':
|
||||
basic_test()
|
||||
basic_test()
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
from engine.controller import scheduler
|
||||
|
||||
# scheduler.Scheduler.Search()
|
|
@ -0,0 +1,48 @@
|
|||
from ..search_index import *
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
class TestSearchSingleThread(unittest.TestCase):
|
||||
def test_search_by_vectors(self):
|
||||
d = 64
|
||||
nb = 10000
|
||||
nq = 100
|
||||
_, xb, xq = get_dataset(d, nb, 500, nq)
|
||||
|
||||
index = faiss.IndexFlatL2(d)
|
||||
index.add(xb)
|
||||
|
||||
# expect result
|
||||
Dref, Iref = index.search(xq, 5)
|
||||
|
||||
searcher = FaissSearch(index)
|
||||
result = searcher.search_by_vectors(xq, 5)
|
||||
|
||||
assert np.all(result.distance == Dref) \
|
||||
and np.all(result.vectors == Iref)
|
||||
pass
|
||||
|
||||
def test_top_k(selfs):
|
||||
pass
|
||||
|
||||
|
||||
def get_dataset(d, nb, nt, nq):
|
||||
"""A dataset that is not completely random but still challenging to
|
||||
index
|
||||
"""
|
||||
d1 = 10 # intrinsic dimension (more or less)
|
||||
n = nb + nt + nq
|
||||
rs = np.random.RandomState(1338)
|
||||
x = rs.normal(size=(n, d1))
|
||||
x = np.dot(x, rs.rand(d1, d))
|
||||
# now we have a d1-dim ellipsoid in d-dimensional space
|
||||
# higher factor (>4) -> higher frequency -> less linear
|
||||
x = x * (rs.rand(d) * 4 + 0.1)
|
||||
x = np.sin(x)
|
||||
x = x.astype('float32')
|
||||
return x[:nt], x[nt:-nq], x[-nq:]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -6,4 +6,4 @@ SQLALCHEMY_TRACK_MODIFICATIONS = False
|
|||
SQLALCHEMY_DATABASE_URI = "mysql+pymysql://vecwise@127.0.0.1:3306/vecdata"
|
||||
|
||||
ROW_LIMIT = 10000000
|
||||
DATABASE_DIRECTORY = '/home/jinhai/disk0/vecwise/db'
|
||||
DATABASE_DIRECTORY = '/tmp'
|
Loading…
Reference in New Issue