fixed formatting
parent
005be024f1
commit
b2bfd395ed
|
@ -27,6 +27,7 @@ except ImportError:
|
||||||
print("Weaviate not installed. Skipping import.")
|
print("Weaviate not installed. Skipping import.")
|
||||||
WeaviateMemory = None
|
WeaviateMemory = None
|
||||||
|
|
||||||
|
|
||||||
def get_memory(cfg, init=False):
|
def get_memory(cfg, init=False):
|
||||||
memory = None
|
memory = None
|
||||||
if cfg.memory_backend == "pinecone":
|
if cfg.memory_backend == "pinecone":
|
||||||
|
|
|
@ -7,6 +7,7 @@ from autogpt.config import AbstractSingleton, Config
|
||||||
|
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
|
|
||||||
|
|
||||||
def get_ada_embedding(text):
|
def get_ada_embedding(text):
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
if cfg.use_azure:
|
if cfg.use_azure:
|
||||||
|
|
|
@ -6,6 +6,7 @@ from weaviate import Client
|
||||||
from weaviate.embedded import EmbeddedOptions
|
from weaviate.embedded import EmbeddedOptions
|
||||||
from weaviate.util import generate_uuid5
|
from weaviate.util import generate_uuid5
|
||||||
|
|
||||||
|
|
||||||
def default_schema(weaviate_index):
|
def default_schema(weaviate_index):
|
||||||
return {
|
return {
|
||||||
"class": weaviate_index,
|
"class": weaviate_index,
|
||||||
|
@ -18,6 +19,7 @@ def default_schema(weaviate_index):
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class WeaviateMemory(MemoryProviderSingleton):
|
class WeaviateMemory(MemoryProviderSingleton):
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
auth_credentials = self._build_auth_credentials(cfg)
|
auth_credentials = self._build_auth_credentials(cfg)
|
||||||
|
@ -72,7 +74,6 @@ class WeaviateMemory(MemoryProviderSingleton):
|
||||||
def get(self, data):
|
def get(self, data):
|
||||||
return self.get_relevant(data, 1)
|
return self.get_relevant(data, 1)
|
||||||
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.client.schema.delete_all()
|
self.client.schema.delete_all()
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ from autogpt.config import Config
|
||||||
from autogpt.memory.weaviate import WeaviateMemory
|
from autogpt.memory.weaviate import WeaviateMemory
|
||||||
from autogpt.memory.base import get_ada_embedding
|
from autogpt.memory.base import get_ada_embedding
|
||||||
|
|
||||||
|
|
||||||
@mock.patch.dict(os.environ, {
|
@mock.patch.dict(os.environ, {
|
||||||
"WEAVIATE_HOST": "127.0.0.1",
|
"WEAVIATE_HOST": "127.0.0.1",
|
||||||
"WEAVIATE_PROTOCOL": "http",
|
"WEAVIATE_PROTOCOL": "http",
|
||||||
|
@ -83,7 +84,6 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||||
self.assertEqual(len(actual), 1)
|
self.assertEqual(len(actual), 1)
|
||||||
self.assertEqual(actual[0], doc)
|
self.assertEqual(actual[0], doc)
|
||||||
|
|
||||||
|
|
||||||
def test_get_stats(self):
|
def test_get_stats(self):
|
||||||
docs = [
|
docs = [
|
||||||
'You are now about to count the number of docs in this index',
|
'You are now about to count the number of docs in this index',
|
||||||
|
@ -98,7 +98,6 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||||
self.assertTrue('count' in stats)
|
self.assertTrue('count' in stats)
|
||||||
self.assertEqual(stats['count'], 2)
|
self.assertEqual(stats['count'], 2)
|
||||||
|
|
||||||
|
|
||||||
def test_clear(self):
|
def test_clear(self):
|
||||||
docs = [
|
docs = [
|
||||||
'Shame this is the last test for this class',
|
'Shame this is the last test for this class',
|
||||||
|
|
Loading…
Reference in New Issue