111 lines
2.9 KiB
Python
111 lines
2.9 KiB
Python
# sourcery skip: snake-case-functions
|
|
"""Tests for LocalCache class"""
|
|
import unittest
|
|
|
|
import orjson
|
|
import pytest
|
|
|
|
from autogpt.memory.local import EMBED_DIM, SAVE_OPTIONS
|
|
from autogpt.memory.local import LocalCache as LocalCache_
|
|
from tests.utils import requires_api_key
|
|
|
|
|
|
@pytest.fixture
|
|
def LocalCache():
|
|
# Hack, real gross. Singletons are not good times.
|
|
if LocalCache_ in LocalCache_._instances:
|
|
del LocalCache_._instances[LocalCache_]
|
|
return LocalCache_
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_embed_with_ada(mocker):
|
|
mocker.patch(
|
|
"autogpt.memory.local.get_ada_embedding",
|
|
return_value=[0.1] * EMBED_DIM,
|
|
)
|
|
|
|
|
|
def test_init_without_backing_file(LocalCache, config, workspace):
|
|
cache_file = workspace.root / f"{config.memory_index}.json"
|
|
|
|
assert not cache_file.exists()
|
|
LocalCache(config)
|
|
assert cache_file.exists()
|
|
assert cache_file.read_text() == "{}"
|
|
|
|
|
|
def test_init_with_backing_empty_file(LocalCache, config, workspace):
|
|
cache_file = workspace.root / f"{config.memory_index}.json"
|
|
cache_file.touch()
|
|
|
|
assert cache_file.exists()
|
|
LocalCache(config)
|
|
assert cache_file.exists()
|
|
assert cache_file.read_text() == "{}"
|
|
|
|
|
|
def test_init_with_backing_file(LocalCache, config, workspace):
|
|
cache_file = workspace.root / f"{config.memory_index}.json"
|
|
cache_file.touch()
|
|
|
|
raw_data = {"texts": ["test"]}
|
|
data = orjson.dumps(raw_data, option=SAVE_OPTIONS)
|
|
with cache_file.open("wb") as f:
|
|
f.write(data)
|
|
|
|
assert cache_file.exists()
|
|
LocalCache(config)
|
|
assert cache_file.exists()
|
|
assert cache_file.read_text() == "{}"
|
|
|
|
|
|
def test_add(LocalCache, config, mock_embed_with_ada):
|
|
cache = LocalCache(config)
|
|
cache.add("test")
|
|
assert cache.data.texts == ["test"]
|
|
assert cache.data.embeddings.shape == (1, EMBED_DIM)
|
|
|
|
|
|
def test_clear(LocalCache, config, mock_embed_with_ada):
|
|
cache = LocalCache(config)
|
|
assert cache.data.texts == []
|
|
assert cache.data.embeddings.shape == (0, EMBED_DIM)
|
|
|
|
cache.add("test")
|
|
assert cache.data.texts == ["test"]
|
|
assert cache.data.embeddings.shape == (1, EMBED_DIM)
|
|
|
|
cache.clear()
|
|
assert cache.data.texts == []
|
|
assert cache.data.embeddings.shape == (0, EMBED_DIM)
|
|
|
|
|
|
def test_get(LocalCache, config, mock_embed_with_ada):
|
|
cache = LocalCache(config)
|
|
assert cache.get("test") == []
|
|
|
|
cache.add("test")
|
|
assert cache.get("test") == ["test"]
|
|
|
|
|
|
@pytest.mark.vcr
|
|
@requires_api_key("OPENAI_API_KEY")
|
|
def test_get_relevant(LocalCache, config) -> None:
|
|
cache = LocalCache(config)
|
|
text1 = "Sample text 1"
|
|
text2 = "Sample text 2"
|
|
cache.add(text1)
|
|
cache.add(text2)
|
|
|
|
result = cache.get_relevant(text1, 1)
|
|
assert result == [text1]
|
|
|
|
|
|
def test_get_stats(LocalCache, config, mock_embed_with_ada) -> None:
|
|
cache = LocalCache(config)
|
|
text = "Sample text"
|
|
cache.add(text)
|
|
stats = cache.get_stats()
|
|
assert stats == (1, cache.data.embeddings.shape)
|