Merge pull request #913 from chozzz/bugfix-823

Bugfix for #840 - Local memory fix
pull/719/head^2
Richard Beales 2023-04-12 20:07:50 +01:00 committed by GitHub
commit 7e3ff66494
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 3 deletions

View File

@ -28,10 +28,20 @@ class LocalCache(MemoryProviderSingleton):
def __init__(self, cfg) -> None:
self.filename = f"{cfg.memory_index}.json"
if os.path.exists(self.filename):
with open(self.filename, 'rb') as f:
loaded = orjson.loads(f.read())
self.data = CacheContent(**loaded)
try:
with open(self.filename, 'w+b') as f:
file_content = f.read()
if not file_content.strip():
file_content = b'{}'
f.write(file_content)
loaded = orjson.loads(file_content)
self.data = CacheContent(**loaded)
except orjson.JSONDecodeError:
print(f"Error: The file '{self.filename}' is not in JSON format.")
self.data = CacheContent()
else:
print(f"Warning: The file '{self.filename}' does not exist. Local memory would not be saved to a file.")
self.data = CacheContent()
def add(self, text: str):

51
tests/local_cache_test.py Normal file
View File

@ -0,0 +1,51 @@
import os
import sys
# Probably a better way:
sys.path.append(os.path.abspath('../scripts'))
from memory.local import LocalCache
def MockConfig():
return type('MockConfig', (object,), {
'debug_mode': False,
'continuous_mode': False,
'speak_mode': False,
'memory_index': 'auto-gpt',
})
class TestLocalCache(unittest.TestCase):
def setUp(self):
self.cfg = MockConfig()
self.cache = LocalCache(self.cfg)
def test_add(self):
text = "Sample text"
self.cache.add(text)
self.assertIn(text, self.cache.data.texts)
def test_clear(self):
self.cache.clear()
self.assertEqual(self.cache.data, [""])
def test_get(self):
text = "Sample text"
self.cache.add(text)
result = self.cache.get(text)
self.assertEqual(result, [text])
def test_get_relevant(self):
text1 = "Sample text 1"
text2 = "Sample text 2"
self.cache.add(text1)
self.cache.add(text2)
result = self.cache.get_relevant(text1, 1)
self.assertEqual(result, [text1])
def test_get_stats(self):
text = "Sample text"
self.cache.add(text)
stats = self.cache.get_stats()
self.assertEqual(stats, (1, self.cache.data.embeddings.shape))
if __name__ == '__main__':
unittest.main()