fixed formatting

pull/424/head
cs0lar 2023-04-15 15:49:24 +01:00
parent 005be024f1
commit b2bfd395ed
4 changed files with 11 additions and 9 deletions

View File

@ -27,6 +27,7 @@ except ImportError:
print("Weaviate not installed. Skipping import.") print("Weaviate not installed. Skipping import.")
WeaviateMemory = None WeaviateMemory = None
def get_memory(cfg, init=False): def get_memory(cfg, init=False):
memory = None memory = None
if cfg.memory_backend == "pinecone": if cfg.memory_backend == "pinecone":
@ -53,7 +54,7 @@ def get_memory(cfg, init=False):
" use Weaviate as a memory backend.") " use Weaviate as a memory backend.")
else: else:
memory = WeaviateMemory(cfg) memory = WeaviateMemory(cfg)
elif cfg.memory_backend == "no_memory": elif cfg.memory_backend == "no_memory":
memory = NoMemory(cfg) memory = NoMemory(cfg)

View File

@ -7,6 +7,7 @@ from autogpt.config import AbstractSingleton, Config
cfg = Config() cfg = Config()
def get_ada_embedding(text): def get_ada_embedding(text):
text = text.replace("\n", " ") text = text.replace("\n", " ")
if cfg.use_azure: if cfg.use_azure:

View File

@ -6,6 +6,7 @@ from weaviate import Client
from weaviate.embedded import EmbeddedOptions from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5 from weaviate.util import generate_uuid5
def default_schema(weaviate_index): def default_schema(weaviate_index):
return { return {
"class": weaviate_index, "class": weaviate_index,
@ -18,6 +19,7 @@ def default_schema(weaviate_index):
], ],
} }
class WeaviateMemory(MemoryProviderSingleton): class WeaviateMemory(MemoryProviderSingleton):
def __init__(self, cfg): def __init__(self, cfg):
auth_credentials = self._build_auth_credentials(cfg) auth_credentials = self._build_auth_credentials(cfg)
@ -72,12 +74,11 @@ class WeaviateMemory(MemoryProviderSingleton):
def get(self, data): def get(self, data):
return self.get_relevant(data, 1) return self.get_relevant(data, 1)
def clear(self): def clear(self):
self.client.schema.delete_all() self.client.schema.delete_all()
# weaviate does not yet have a neat way to just remove the items in an index # weaviate does not yet have a neat way to just remove the items in an index
# without removing the entire schema, therefore we need to re-create it # without removing the entire schema, therefore we need to re-create it
# after a call to delete_all # after a call to delete_all
self._create_schema() self._create_schema()

View File

@ -11,6 +11,7 @@ from autogpt.config import Config
from autogpt.memory.weaviate import WeaviateMemory from autogpt.memory.weaviate import WeaviateMemory
from autogpt.memory.base import get_ada_embedding from autogpt.memory.base import get_ada_embedding
@mock.patch.dict(os.environ, { @mock.patch.dict(os.environ, {
"WEAVIATE_HOST": "127.0.0.1", "WEAVIATE_HOST": "127.0.0.1",
"WEAVIATE_PROTOCOL": "http", "WEAVIATE_PROTOCOL": "http",
@ -38,13 +39,13 @@ class TestWeaviateMemory(unittest.TestCase):
)) ))
else: else:
cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}") cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}")
""" """
In order to run these tests you will need a local instance of In order to run these tests you will need a local instance of
Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose
for creating local instances using docker. for creating local instances using docker.
Alternatively in your .env file set the following environmental variables to run Weaviate embedded (see: https://weaviate.io/developers/weaviate/installation/embedded): Alternatively in your .env file set the following environmental variables to run Weaviate embedded (see: https://weaviate.io/developers/weaviate/installation/embedded):
USE_WEAVIATE_EMBEDDED=True USE_WEAVIATE_EMBEDDED=True
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate" WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
""" """
@ -53,7 +54,7 @@ class TestWeaviateMemory(unittest.TestCase):
self.client.schema.delete_class(self.cfg.memory_index) self.client.schema.delete_class(self.cfg.memory_index)
except: except:
pass pass
self.memory = WeaviateMemory(self.cfg) self.memory = WeaviateMemory(self.cfg)
def test_add(self): def test_add(self):
@ -67,7 +68,7 @@ class TestWeaviateMemory(unittest.TestCase):
def test_get(self): def test_get(self):
doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos' doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos'
with self.client.batch as batch: with self.client.batch as batch:
batch.add_data_object( batch.add_data_object(
uuid=get_valid_uuid(uuid4()), uuid=get_valid_uuid(uuid4()),
@ -83,7 +84,6 @@ class TestWeaviateMemory(unittest.TestCase):
self.assertEqual(len(actual), 1) self.assertEqual(len(actual), 1)
self.assertEqual(actual[0], doc) self.assertEqual(actual[0], doc)
def test_get_stats(self): def test_get_stats(self):
docs = [ docs = [
'You are now about to count the number of docs in this index', 'You are now about to count the number of docs in this index',
@ -98,7 +98,6 @@ class TestWeaviateMemory(unittest.TestCase):
self.assertTrue('count' in stats) self.assertTrue('count' in stats)
self.assertEqual(stats['count'], 2) self.assertEqual(stats['count'], 2)
def test_clear(self): def test_clear(self):
docs = [ docs = [
'Shame this is the last test for this class', 'Shame this is the last test for this class',