import unittest
from uuid import uuid4

from weaviate import Client
from weaviate.util import get_valid_uuid

from autogpt.config import Config
from autogpt.llm import get_ada_embedding
from autogpt.memory.weaviate import WeaviateMemory


class TestWeaviateMemory(unittest.TestCase):
    """Integration tests for ``WeaviateMemory`` against a Weaviate instance.

    In order to run these tests you will need a local instance of Weaviate
    running. Refer to
    https://weaviate.io/developers/weaviate/installation/docker-compose
    for creating local instances using docker.

    Alternatively, in your .env file set the following environment variables
    to run Weaviate embedded
    (see: https://weaviate.io/developers/weaviate/installation/embedded):

        USE_WEAVIATE_EMBEDDED=True
        WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
    """

    # Shared across all tests; populated once in setUpClass.
    cfg = None
    client = None
    index = None

    @classmethod
    def setUpClass(cls):
        """Create the shared config, Weaviate client and index name once."""
        # only create the connection to weaviate once
        cls.cfg = Config()

        if cls.cfg.use_weaviate_embedded:
            # Imported lazily: only needed when running Weaviate embedded.
            from weaviate.embedded import EmbeddedOptions

            cls.client = Client(
                embedded_options=EmbeddedOptions(
                    hostname=cls.cfg.weaviate_host,
                    port=int(cls.cfg.weaviate_port),
                    persistence_data_path=cls.cfg.weaviate_embedded_path,
                )
            )
        else:
            # NOTE: this is a classmethod, so the config must be read through
            # ``cls`` -- there is no ``self`` in scope here.
            url = (
                f"{cls.cfg.weaviate_protocol}://"
                f"{cls.cfg.weaviate_host}:{cls.cfg.weaviate_port}"
            )
            cls.client = Client(url)

        cls.index = WeaviateMemory.format_classname(cls.cfg.memory_index)

    def setUp(self):
        """Drop the test class from the schema and build a fresh memory."""
        try:
            self.client.schema.delete_class(self.index)
        except Exception:
            # Best-effort cleanup: the delete is allowed to fail (presumably
            # when the class does not exist yet -- TODO confirm). Catching
            # Exception rather than using a bare ``except`` lets
            # KeyboardInterrupt/SystemExit propagate.
            pass

        self.memory = WeaviateMemory(self.cfg)

    def test_add(self):
        """Test adding a text to the cache"""
        doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
        self.memory.add(doc)

        # Read back through the raw client to check what was stored.
        result = self.client.query.get(self.index, ["raw_text"]).do()
        actual = result["data"]["Get"][self.index]

        self.assertEqual(len(actual), 1)
        self.assertEqual(actual[0]["raw_text"], doc)

    def test_get(self):
        """Test getting a text from the cache"""
        doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"

        # add the document to the cache directly through the client, so that
        # only ``WeaviateMemory.get`` is exercised by this test
        with self.client.batch as batch:
            batch.add_data_object(
                uuid=get_valid_uuid(uuid4()),
                data_object={"raw_text": doc},
                class_name=self.index,
                vector=get_ada_embedding(doc),
            )

            batch.flush()

        actual = self.memory.get(doc)

        self.assertEqual(len(actual), 1)
        self.assertEqual(actual[0], doc)

    def test_get_stats(self):
        """Test getting the stats of the cache"""
        docs = [
            "You are now about to count the number of docs in this index",
            "And then you about to find out if you can count correctly",
        ]

        for doc in docs:
            self.memory.add(doc)

        stats = self.memory.get_stats()

        self.assertTrue(stats)
        self.assertIn("count", stats)
        self.assertEqual(stats["count"], 2)

    def test_clear(self):
        """Test clearing the cache"""
        docs = [
            "Shame this is the last test for this class",
            "Testing is fun when someone else is doing it",
        ]

        for doc in docs:
            self.memory.add(doc)

        self.assertEqual(self.memory.get_stats()["count"], 2)

        self.memory.clear()

        self.assertEqual(self.memory.get_stats()["count"], 0)


if __name__ == "__main__":
    unittest.main()