fixed formatting
parent
005be024f1
commit
b2bfd395ed
|
@ -27,6 +27,7 @@ except ImportError:
|
|||
print("Weaviate not installed. Skipping import.")
|
||||
WeaviateMemory = None
|
||||
|
||||
|
||||
def get_memory(cfg, init=False):
|
||||
memory = None
|
||||
if cfg.memory_backend == "pinecone":
|
||||
|
@ -53,7 +54,7 @@ def get_memory(cfg, init=False):
|
|||
" use Weaviate as a memory backend.")
|
||||
else:
|
||||
memory = WeaviateMemory(cfg)
|
||||
|
||||
|
||||
elif cfg.memory_backend == "no_memory":
|
||||
memory = NoMemory(cfg)
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ from autogpt.config import AbstractSingleton, Config
|
|||
|
||||
cfg = Config()
|
||||
|
||||
|
||||
def get_ada_embedding(text):
|
||||
text = text.replace("\n", " ")
|
||||
if cfg.use_azure:
|
||||
|
|
|
@ -6,6 +6,7 @@ from weaviate import Client
|
|||
from weaviate.embedded import EmbeddedOptions
|
||||
from weaviate.util import generate_uuid5
|
||||
|
||||
|
||||
def default_schema(weaviate_index):
|
||||
return {
|
||||
"class": weaviate_index,
|
||||
|
@ -18,6 +19,7 @@ def default_schema(weaviate_index):
|
|||
],
|
||||
}
|
||||
|
||||
|
||||
class WeaviateMemory(MemoryProviderSingleton):
|
||||
def __init__(self, cfg):
|
||||
auth_credentials = self._build_auth_credentials(cfg)
|
||||
|
@ -72,12 +74,11 @@ class WeaviateMemory(MemoryProviderSingleton):
|
|||
def get(self, data):
|
||||
return self.get_relevant(data, 1)
|
||||
|
||||
|
||||
def clear(self):
|
||||
self.client.schema.delete_all()
|
||||
|
||||
# weaviate does not yet have a neat way to just remove the items in an index
|
||||
# without removing the entire schema, therefore we need to re-create it
|
||||
# without removing the entire schema, therefore we need to re-create it
|
||||
# after a call to delete_all
|
||||
self._create_schema()
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from autogpt.config import Config
|
|||
from autogpt.memory.weaviate import WeaviateMemory
|
||||
from autogpt.memory.base import get_ada_embedding
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {
|
||||
"WEAVIATE_HOST": "127.0.0.1",
|
||||
"WEAVIATE_PROTOCOL": "http",
|
||||
|
@ -38,13 +39,13 @@ class TestWeaviateMemory(unittest.TestCase):
|
|||
))
|
||||
else:
|
||||
cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}")
|
||||
|
||||
|
||||
"""
|
||||
In order to run these tests you will need a local instance of
|
||||
Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose
|
||||
for creating local instances using docker.
|
||||
Alternatively in your .env file set the following environmental variables to run Weaviate embedded (see: https://weaviate.io/developers/weaviate/installation/embedded):
|
||||
|
||||
|
||||
USE_WEAVIATE_EMBEDDED=True
|
||||
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
|
||||
"""
|
||||
|
@ -53,7 +54,7 @@ class TestWeaviateMemory(unittest.TestCase):
|
|||
self.client.schema.delete_class(self.cfg.memory_index)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
self.memory = WeaviateMemory(self.cfg)
|
||||
|
||||
def test_add(self):
|
||||
|
@ -67,7 +68,7 @@ class TestWeaviateMemory(unittest.TestCase):
|
|||
|
||||
def test_get(self):
|
||||
doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos'
|
||||
|
||||
|
||||
with self.client.batch as batch:
|
||||
batch.add_data_object(
|
||||
uuid=get_valid_uuid(uuid4()),
|
||||
|
@ -83,7 +84,6 @@ class TestWeaviateMemory(unittest.TestCase):
|
|||
self.assertEqual(len(actual), 1)
|
||||
self.assertEqual(actual[0], doc)
|
||||
|
||||
|
||||
def test_get_stats(self):
|
||||
docs = [
|
||||
'You are now about to count the number of docs in this index',
|
||||
|
@ -98,7 +98,6 @@ class TestWeaviateMemory(unittest.TestCase):
|
|||
self.assertTrue('count' in stats)
|
||||
self.assertEqual(stats['count'], 2)
|
||||
|
||||
|
||||
def test_clear(self):
|
||||
docs = [
|
||||
'Shame this is the last test for this class',
|
||||
|
|
Loading…
Reference in New Issue