From 986d32ca42671b1e092fa2842196463ac0992915 Mon Sep 17 00:00:00 2001 From: cs0lar Date: Fri, 7 Apr 2023 20:41:07 +0100 Subject: [PATCH] added support for multiple memory provider and added weaviate integration --- .env.template | 8 +- README.md | 26 ++++- requirements.txt | 1 + scripts/commands.py | 4 +- scripts/config.py | 8 ++ scripts/factory.py | 11 ++ scripts/main.py | 4 +- scripts/providers/__init__.py | 0 scripts/providers/memory.py | 26 +++++ scripts/{memory.py => providers/pinecone.py} | 18 +--- scripts/providers/weaviate.py | 100 +++++++++++++++++++ 11 files changed, 185 insertions(+), 21 deletions(-) create mode 100644 scripts/factory.py create mode 100644 scripts/providers/__init__.py create mode 100644 scripts/providers/memory.py rename scripts/{memory.py => providers/pinecone.py} (80%) create mode 100644 scripts/providers/weaviate.py diff --git a/.env.template b/.env.template index e9ccda5ed..c9a45b2b2 100644 --- a/.env.template +++ b/.env.template @@ -9,4 +9,10 @@ CUSTOM_SEARCH_ENGINE_ID= USE_AZURE=False OPENAI_API_BASE=your-base-url-for-azure OPENAI_API_VERSION=api-version-for-azure -OPENAI_DEPLOYMENT_ID=deployment-id-for-azure \ No newline at end of file +OPENAI_DEPLOYMENT_ID=deployment-id-for-azure +WEAVIATE_HOST="http://127.0.0.1" +WEAVIATE_PORT="8080" +WEAVIATE_USERNAME= +WEAVIATE_PASSWORD= +WEAVIATE_INDEX="Autogpt" +MEMORY_PROVIDER="weaviate" \ No newline at end of file diff --git a/README.md b/README.md index a89c5d03b..9e8c24f2f 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,13 @@ export CUSTOM_SEARCH_ENGINE_ID="YOUR_CUSTOM_SEARCH_ENGINE_ID" ``` -## 🌲 Pinecone API Key Setup +## Vector based memory provider +Auto-GPT supports two providers for vector-based memory, [Pinecone](https://www.pinecone.io/) and [Weaviate](https://weaviate.io/). 
To select the provider to use, specify the following in your `.env`: + +``` +MEMORY_PROVIDER="pinecone" # change to "weaviate" to use weaviate as the memory provider +``` +### 🌲 Pinecone API Key Setup Pinecone enable a vector based memory so a vast memory can be stored and only relevant memories are loaded for the agent at any given time. @@ -149,7 +155,7 @@ are loaded for the agent at any given time. 2. Choose the `Starter` plan to avoid being charged. 3. Find your API key and region under the default project in the left sidebar. -### Setting up environment variables +#### Setting up environment variables For Windows Users: ``` setx PINECONE_API_KEY "YOUR_PINECONE_API_KEY" @@ -165,6 +171,22 @@ export PINECONE_ENV="Your pinecone region" # something like: us-east4-gcp Or you can set them in the `.env` file. +### Weaviate Setup + +[Weaviate](https://weaviate.io/) is an open-source vector database. It allows you to store data objects and vector embeddings from ML-models and scales seamlessly to billions of data objects. [An instance of Weaviate can be created locally (using Docker), on Kubernetes or using Weaviate Cloud Services](https://weaviate.io/developers/weaviate/quickstart). + +#### Setting up environment variables + +In your `.env` file set the following: + +``` +WEAVIATE_HOST="http://127.0.0.1" # the URL of the running Weaviate instance +WEAVIATE_PORT="8080" +WEAVIATE_USERNAME="your username" +WEAVIATE_PASSWORD="your password" +WEAVIATE_INDEX="Autogpt" # name of the index to create for the application +``` + ## View Memory Usage 1. 
View memory usage by using the `--debug` flag :) diff --git a/requirements.txt b/requirements.txt index ce2470985..aed03226a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ docker duckduckgo-search google-api-python-client #(https://developers.google.com/custom-search/v1/overview) pinecone-client==2.2.1 +weaviate-client==3.15.4 diff --git a/scripts/commands.py b/scripts/commands.py index fc10d1d05..13037c343 100644 --- a/scripts/commands.py +++ b/scripts/commands.py @@ -1,6 +1,6 @@ import browse import json -from memory import PineconeMemory +from factory import MemoryFactory import datetime import agent_manager as agents import speak @@ -52,7 +52,7 @@ def get_command(response): def execute_command(command_name, arguments): - memory = PineconeMemory() + memory = MemoryFactory.get_memory(cfg.memory_provider) try: if command_name == "google": diff --git a/scripts/config.py b/scripts/config.py index fe48d2980..bc88bbb9d 100644 --- a/scripts/config.py +++ b/scripts/config.py @@ -50,9 +50,17 @@ class Config(metaclass=Singleton): self.google_api_key = os.getenv("GOOGLE_API_KEY") self.custom_search_engine_id = os.getenv("CUSTOM_SEARCH_ENGINE_ID") + self.memory_provider = os.getenv("MEMORY_PROVIDER", 'pinecone') self.pinecone_api_key = os.getenv("PINECONE_API_KEY") self.pinecone_region = os.getenv("PINECONE_ENV") + self.weaviate_host = os.getenv("WEAVIATE_HOST") + self.weaviate_port = os.getenv("WEAVIATE_PORT") + self.weaviate_username = os.getenv("WEAVIATE_USERNAME", None) + self.weaviate_password = os.getenv("WEAVIATE_PASSWORD", None) + self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None) + self.weaviate_index = os.getenv("WEAVIATE_INDEX", 'auto-gpt') + # User agent headers to use when browsing web # Some websites might just completely deny request with an error code if no user agent was found. 
self.user_agent_header = {"User-Agent":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"} diff --git a/scripts/factory.py b/scripts/factory.py new file mode 100644 index 000000000..44901631c --- /dev/null +++ b/scripts/factory.py @@ -0,0 +1,11 @@ +from providers.pinecone import PineconeMemory +from providers.weaviate import WeaviateMemory + +class MemoryFactory: + @staticmethod + def get_memory(mem_type): + if mem_type == 'pinecone': + return PineconeMemory() + + if mem_type == 'weaviate': + return WeaviateMemory() \ No newline at end of file diff --git a/scripts/main.py b/scripts/main.py index a79fd553c..795c2ac4e 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -1,7 +1,7 @@ import json import random import commands as cmd -from memory import PineconeMemory +from factory import MemoryFactory import data import chat from colorama import Fore, Style @@ -283,7 +283,7 @@ user_input = "Determine which next command to use, and respond using the format # Initialize memory and make sure it is empty. 
# this is particularly important for indexing and referencing pinecone memory -memory = PineconeMemory() +memory = MemoryFactory.get_memory(cfg.memory_provider) memory.clear() print('Using memory of type: ' + memory.__class__.__name__) diff --git a/scripts/providers/__init__.py b/scripts/providers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/scripts/providers/memory.py b/scripts/providers/memory.py new file mode 100644 index 000000000..0440536e8 --- /dev/null +++ b/scripts/providers/memory.py @@ -0,0 +1,26 @@ +from config import Singleton +import openai + +def get_ada_embedding(text): + text = text.replace("\n", " ") + return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"] + + +def get_text_from_embedding(embedding): + return openai.Embedding.retrieve(embedding, model="text-embedding-ada-002")["data"][0]["text"] + +class Memory(metaclass=Singleton): + def add(self, data): + raise NotImplementedError() + + def get(self, data): + raise NotImplementedError() + + def clear(self): + raise NotImplementedError() + + def get_relevant(self, data, num_relevant=5): + raise NotImplementedError() + + def get_stats(self): + raise NotImplementedError() \ No newline at end of file diff --git a/scripts/memory.py b/scripts/providers/pinecone.py similarity index 80% rename from scripts/memory.py rename to scripts/providers/pinecone.py index 0d265a31d..971ef1869 100644 --- a/scripts/memory.py +++ b/scripts/providers/pinecone.py @@ -1,20 +1,10 @@ -from config import Config, Singleton +from config import Config +from providers.memory import Memory, get_ada_embedding import pinecone -import openai cfg = Config() - -def get_ada_embedding(text): - text = text.replace("\n", " ") - return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"] - - -def get_text_from_embedding(embedding): - return openai.Embedding.retrieve(embedding, 
model="text-embedding-ada-002")["data"][0]["text"] - - -class PineconeMemory(metaclass=Singleton): +class PineconeMemory(Memory): def __init__(self): pinecone_api_key = cfg.pinecone_api_key pinecone_region = cfg.pinecone_region @@ -58,4 +48,4 @@ class PineconeMemory(metaclass=Singleton): return [str(item['metadata']["raw_text"]) for item in sorted_results] def get_stats(self): - return self.index.describe_index_stats() + return self.index.describe_index_stats() \ No newline at end of file diff --git a/scripts/providers/weaviate.py b/scripts/providers/weaviate.py new file mode 100644 index 000000000..21718a033 --- /dev/null +++ b/scripts/providers/weaviate.py @@ -0,0 +1,100 @@ +from config import Config +from providers.memory import Memory, get_ada_embedding +from weaviate import Client +import weaviate +import uuid + +from weaviate.util import generate_uuid5 + +cfg = Config() + +SCHEMA = { + "class": cfg.weaviate_index, + "properties": [ + { + "name": "raw_text", + "dataType": ["text"], + "description": "original text for the embedding" + } + ], +} + +class WeaviateMemory(Memory): + + def __init__(self): + auth_credentials = self._build_auth_credentials() + + url = f'{cfg.weaviate_host}:{cfg.weaviate_port}' + + self.client = Client(url, auth_client_secret=auth_credentials) + + self._create_schema() + + def _create_schema(self): + if not self.client.schema.contains(SCHEMA): + self.client.schema.create_class(SCHEMA) + + @staticmethod + def _build_auth_credentials(): + if cfg.weaviate_username and cfg.weaviate_password: + return weaviate.AuthClientPassword(cfg.weaviate_username, cfg.weaviate_password) + else: + return None + + def add(self, data): + vector = get_ada_embedding(data) + + doc_uuid = generate_uuid5(data, cfg.weaviate_index) + data_object = { + 'class': cfg.weaviate_index, + 'raw_text': data + } + + with self.client.batch as batch: + batch.add_data_object( + uuid=doc_uuid, + data_object=data_object, + class_name=cfg.weaviate_index, + vector=vector + ) 
+ + batch.flush() + + return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}" + + + def get(self, data): + return self.get_relevant(data, 1) + + + def clear(self): + self.client.schema.delete_all() + + # weaviate does not yet have a neat way to just remove the items in an index + # without removing the entire schema, therefore we need to re-create it + # after a call to delete_all + self._create_schema() + + return 'Obliterated' + + def get_relevant(self, data, num_relevant=5): + query_embedding = get_ada_embedding(data) + try: + results = self.client.query.get(cfg.weaviate_index, ['raw_text']) \ + .with_near_vector({'vector': query_embedding, 'certainty': 0.7}) \ + .with_limit(num_relevant) \ + .do() + + print(results) + + if len(results['data']['Get'][cfg.weaviate_index]) > 0: + return [str(item['raw_text']) for item in results['data']['Get'][cfg.weaviate_index]] + else: + return [] + + except Exception as err: + print(f'Unexpected error {err=}, {type(err)=}') + return [] + + def get_stats(self): + return self.client.query.aggregate(cfg.weaviate_index).with_meta_count().do() \ No newline at end of file