mirror of https://github.com/milvus-io/milvus.git
add list => array
parent 5616ec74db
commit dbfb4f2057
@@ -1,7 +1,7 @@
 from engine.retrieval import search_index
 from engine.ingestion import build_index
 from engine.ingestion import serialize
-import numpy as np
+
 
 class Singleton(type):
     _instances = {}
@@ -15,36 +15,29 @@ class Scheduler(metaclass=Singleton):
     def Search(self, index_file_key, vectors, k):
         # assert index_file_key
         # assert vectors
-        # assert k
+        assert k != 0
 
-        return self.__scheduler(index_file_key, vectors, k)
+        query_vectors = serialize.to_array(vectors)
+        return self.__scheduler(index_file_key, query_vectors, k)
 
 
     def __scheduler(self, index_data_key, vectors, k):
         result_list = []
 
-        d = None
-        raw_vectors = None
-        print("__scheduler: vectors: ", vectors)
-        query_vectors = np.asarray(vectors).astype('float32')
-
         if 'raw' in index_data_key:
             raw_vectors = index_data_key['raw']
-            raw_vectors = np.asarray(raw_vectors).astype('float32')
             d = index_data_key['dimension']
 
-        if 'raw' in index_data_key:
             index_builder = build_index.FactoryIndex()
-            print("d: ", d, " raw_vectors: ", raw_vectors)
             index = index_builder().build(d, raw_vectors)
             searcher = search_index.FaissSearch(index)
-            result_list.append(searcher.search_by_vectors(query_vectors, k))
+            result_list.append(searcher.search_by_vectors(vectors, k))
 
         index_data_list = index_data_key['index']
         for key in index_data_list:
             index = GetIndexData(key)
             searcher = search_index.FaissSearch(index)
-            result_list.append(searcher.search_by_vectors(query_vectors, k))
+            result_list.append(searcher.search_by_vectors(vectors, k))
 
         if len(result_list) == 1:
             return result_list[0].vectors
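This hunk folds the ad-hoc `np.asarray(...).astype('float32')` conversions in `__scheduler` into a single `serialize.to_array` call at the top of `Search`, so the list-to-array conversion happens once per query. A minimal sketch of what `to_array` presumably looks like, inferred from the inline conversions removed above (the actual `engine.ingestion.serialize` implementation may differ):

```python
# Hypothetical sketch of engine.ingestion.serialize.to_array, assuming it
# mirrors the np.asarray(...).astype('float32') calls this commit removes.
import numpy as np

def to_array(vectors):
    # Faiss searches expect a contiguous float32 matrix of shape (n, d),
    # not a plain Python list of lists, so convert once at the entry point.
    return np.asarray(vectors).astype('float32')
```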
@@ -3,6 +3,7 @@ from engine.settings import DATABASE_DIRECTORY
 from flask import jsonify
 import pytest
 import os
+import numpy as np
 import logging
 
 logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -104,10 +105,12 @@ class TestVectorEngine:
         expected_list = [self.__vector]
         vector_list = VectorEngine.GetVectorListFromRawFile('test_group', filename)
 
+
         print('expected_list: ', expected_list)
         print('vector_list: ', vector_list)
+        expected_list = np.asarray(expected_list).astype('float32')
 
-        assert vector_list == expected_list
+        assert np.all(vector_list == expected_list)
 
         code = VectorEngine.ClearRawFile('test_group')
         assert code == VectorEngine.SUCCESS_CODE
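The assertion has to change because `GetVectorListFromRawFile` now returns a numpy array: `==` on an array compares elementwise and yields an array of booleans, and truth-testing a multi-element array raises a `ValueError`. A small standalone illustration of the numpy semantics involved:

```python
import numpy as np

a = np.asarray([[1.0, 2.0]]).astype('float32')
b = [[1.0, 2.0]]

print(a == b)          # [[ True  True]] -- an elementwise array, not a bool
assert np.all(a == b)  # np.all collapses the elementwise result to one bool

# A bare `assert a == b` would instead raise:
# ValueError: The truth value of an array with more than one element is
# ambiguous. Use a.any() or a.all()
```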
@@ -144,7 +144,7 @@ class VectorEngine(object):
         index_keys = [ i.filename for i in files if i.type == 'index' ]
         index_map = {}
         index_map['index'] = index_keys
-        index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename")
+        index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") #TODO: pass by key, get from storage
         index_map['dimension'] = group.dimension
 
         scheduler_instance = Scheduler()
@@ -188,8 +188,7 @@ class VectorEngine(object):
 
     @staticmethod
     def GetVectorListFromRawFile(group_id, filename="todo"):
-        return VectorEngine.group_dict[group_id]
-        # return serialize.to_array(VectorEngine.group_dict[group_id])
+        return serialize.to_array(VectorEngine.group_dict[group_id])
 
     @staticmethod
     def ClearRawFile(group_id):
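Taken together, the hunks make the data flow consistent: `GetVectorListFromRawFile` returns an array, `Search` converts incoming query vectors once, and `__scheduler` just forwards them. A hedged sketch of the call path as it reads after this commit (`vectors`, `k`, `files`, and `group` are assumed to come from the surrounding request-handling code, which is not part of this diff):

```python
# Assumed driver code; only the names shown in the diff are taken from it.
index_map = {}
index_map['index'] = [i.filename for i in files if i.type == 'index']
index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename")
index_map['dimension'] = group.dimension

scheduler_instance = Scheduler()
# Search runs serialize.to_array(vectors) once, so both the freshly built
# raw-vector index and every stored index in index_map['index'] are queried
# with the same float32 numpy array.
results = scheduler_instance.Search(index_map, vectors, k)
```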