From dbfb4f20574f92362a7127a24fbced07d322b8f9 Mon Sep 17 00:00:00 2001 From: "xj.lin" Date: Mon, 25 Mar 2019 16:34:40 +0800 Subject: [PATCH] add list => array --- pyengine/engine/controller/scheduler.py | 21 +++++++------------ .../controller/tests/test_vector_engine.py | 5 ++++- pyengine/engine/controller/vector_engine.py | 5 ++--- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/pyengine/engine/controller/scheduler.py b/pyengine/engine/controller/scheduler.py index c88f734840..3eab1b49b3 100644 --- a/pyengine/engine/controller/scheduler.py +++ b/pyengine/engine/controller/scheduler.py @@ -1,7 +1,7 @@ from engine.retrieval import search_index from engine.ingestion import build_index from engine.ingestion import serialize -import numpy as np + class Singleton(type): _instances = {} @@ -15,36 +15,29 @@ class Scheduler(metaclass=Singleton): def Search(self, index_file_key, vectors, k): # assert index_file_key # assert vectors - # assert k + assert k != 0 - return self.__scheduler(index_file_key, vectors, k) + query_vectors = serialize.to_array(vectors) + + return self.__scheduler(index_file_key, query_vectors, k) def __scheduler(self, index_data_key, vectors, k): result_list = [] - d = None - raw_vectors = None - print("__scheduler: vectors: ", vectors) - query_vectors = np.asarray(vectors).astype('float32') - if 'raw' in index_data_key: raw_vectors = index_data_key['raw'] - raw_vectors = np.asarray(raw_vectors).astype('float32') d = index_data_key['dimension'] - - if 'raw' in index_data_key: index_builder = build_index.FactoryIndex() - print("d: ", d, " raw_vectors: ", raw_vectors) index = index_builder().build(d, raw_vectors) searcher = search_index.FaissSearch(index) - result_list.append(searcher.search_by_vectors(query_vectors, k)) + result_list.append(searcher.search_by_vectors(vectors, k)) index_data_list = index_data_key['index'] for key in index_data_list: index = GetIndexData(key) searcher = search_index.FaissSearch(index) - result_list.append(searcher.search_by_vectors(query_vectors, k)) + result_list.append(searcher.search_by_vectors(vectors, k)) if len(result_list) == 1: return result_list[0].vectors diff --git a/pyengine/engine/controller/tests/test_vector_engine.py b/pyengine/engine/controller/tests/test_vector_engine.py index 0a7d193482..25eb5a9ed6 100644 --- a/pyengine/engine/controller/tests/test_vector_engine.py +++ b/pyengine/engine/controller/tests/test_vector_engine.py @@ -3,6 +3,7 @@ from engine.settings import DATABASE_DIRECTORY from flask import jsonify import pytest import os +import numpy as np import logging logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -104,10 +105,12 @@ class TestVectorEngine: expected_list = [self.__vector] vector_list = VectorEngine.GetVectorListFromRawFile('test_group', filename) + print('expected_list: ', expected_list) print('vector_list: ', vector_list) + expected_list = np.asarray(expected_list).astype('float32') - assert vector_list == expected_list + assert np.all(vector_list == expected_list) code = VectorEngine.ClearRawFile('test_group') assert code == VectorEngine.SUCCESS_CODE diff --git a/pyengine/engine/controller/vector_engine.py b/pyengine/engine/controller/vector_engine.py index 0d1e32f7b8..9a44b2c02b 100644 --- a/pyengine/engine/controller/vector_engine.py +++ b/pyengine/engine/controller/vector_engine.py @@ -144,7 +144,7 @@ class VectorEngine(object): index_keys = [ i.filename for i in files if i.type == 'index' ] index_map = {} index_map['index'] = index_keys - index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") + index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") #TODO: pass by key, get from storage index_map['dimension'] = group.dimension scheduler_instance = Scheduler() @@ -188,8 +188,7 @@ class VectorEngine(object): @staticmethod def GetVectorListFromRawFile(group_id, filename="todo"): - return VectorEngine.group_dict[group_id] - # return serialize.to_array(VectorEngine.group_dict[group_id]) + return serialize.to_array(VectorEngine.group_dict[group_id]) @staticmethod def ClearRawFile(group_id):