added support for multiple memory provider and added weaviate integration
parent
c6d90227fe
commit
986d32ca42
|
@ -9,4 +9,10 @@ CUSTOM_SEARCH_ENGINE_ID=
|
|||
USE_AZURE=False
|
||||
OPENAI_API_BASE=your-base-url-for-azure
|
||||
OPENAI_API_VERSION=api-version-for-azure
|
||||
OPENAI_DEPLOYMENT_ID=deployment-id-for-azure
|
||||
OPENAI_DEPLOYMENT_ID=deployment-id-for-azure
|
||||
WEAVIATE_HOST="http://127.0.0.1"
|
||||
WEAVIATE_PORT="8080"
|
||||
WEAVIATE_USERNAME=
|
||||
WEAVIATE_PASSWORD=
|
||||
WEAVIATE_INDEX="Autogpt"
|
||||
MEMORY_PROVIDER="weaviate"
|
26
README.md
26
README.md
|
@ -140,7 +140,13 @@ export CUSTOM_SEARCH_ENGINE_ID="YOUR_CUSTOM_SEARCH_ENGINE_ID"
|
|||
|
||||
```
|
||||
|
||||
## 🌲 Pinecone API Key Setup
|
||||
## Vector based memory provider
|
||||
Auto-GPT supports two providers for vector-based memory, [Pinecone](https://www.pinecone.io/) and [Weaviate](https://weaviate.io/). To select the provider to use, specify the following in your `.env`:
|
||||
|
||||
```
|
||||
MEMORY_PROVIDER="pinecone" # change to "weaviate" to use weaviate as the memory provider
|
||||
```
|
||||
### 🌲 Pinecone API Key Setup
|
||||
|
||||
Pinecone enables a vector-based memory so a vast memory can be stored and only relevant memories
|
||||
are loaded for the agent at any given time.
|
||||
|
@ -149,7 +155,7 @@ are loaded for the agent at any given time.
|
|||
2. Choose the `Starter` plan to avoid being charged.
|
||||
3. Find your API key and region under the default project in the left sidebar.
|
||||
|
||||
### Setting up environment variables
|
||||
#### Setting up environment variables
|
||||
For Windows Users:
|
||||
```
|
||||
setx PINECONE_API_KEY "YOUR_PINECONE_API_KEY"
|
||||
|
@ -165,6 +171,22 @@ export PINECONE_ENV="Your pinecone region" # something like: us-east4-gcp
|
|||
|
||||
Or you can set them in the `.env` file.
|
||||
|
||||
### Weaviate Setup
|
||||
|
||||
[Weaviate](https://weaviate.io/) is an open-source vector database. It allows you to store data objects and vector embeddings from ML models and scales seamlessly to billions of data objects. [An instance of Weaviate can be created locally (using Docker), on Kubernetes or using Weaviate Cloud Services](https://weaviate.io/developers/weaviate/quickstart).
|
||||
|
||||
#### Setting up environment variables
|
||||
|
||||
In your `.env` file set the following:
|
||||
|
||||
```
|
||||
WEAVIATE_HOST="http://127.0.0.1" # the URL of the running Weaviate instance
|
||||
WEAVIATE_PORT="8080"
|
||||
WEAVIATE_USERNAME="your username"
|
||||
WEAVIATE_PASSWORD="your password"
|
||||
WEAVIATE_INDEX="Autogpt" # name of the index to create for the application
|
||||
```
|
||||
|
||||
## View Memory Usage
|
||||
|
||||
1. View memory usage by using the `--debug` flag :)
|
||||
|
|
|
@ -12,3 +12,4 @@ docker
|
|||
duckduckgo-search
|
||||
google-api-python-client #(https://developers.google.com/custom-search/v1/overview)
|
||||
pinecone-client==2.2.1
|
||||
weaviate-client==3.15.4
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import browse
|
||||
import json
|
||||
from memory import PineconeMemory
|
||||
from factory import MemoryFactory
|
||||
import datetime
|
||||
import agent_manager as agents
|
||||
import speak
|
||||
|
@ -52,7 +52,7 @@ def get_command(response):
|
|||
|
||||
|
||||
def execute_command(command_name, arguments):
|
||||
memory = PineconeMemory()
|
||||
memory = MemoryFactory.get_memory(cfg.memory_provider)
|
||||
try:
|
||||
if command_name == "google":
|
||||
|
||||
|
|
|
@ -50,9 +50,17 @@ class Config(metaclass=Singleton):
|
|||
self.google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||
self.custom_search_engine_id = os.getenv("CUSTOM_SEARCH_ENGINE_ID")
|
||||
|
||||
self.memory_provider = os.getenv("MEMORY_PROVIDER", 'pinecone')
|
||||
self.pinecone_api_key = os.getenv("PINECONE_API_KEY")
|
||||
self.pinecone_region = os.getenv("PINECONE_ENV")
|
||||
|
||||
self.weaviate_host = os.getenv("WEAVIATE_HOST")
|
||||
self.weaviate_port = os.getenv("WEAVIATE_PORT")
|
||||
self.weaviate_username = os.getenv("WEAVIATE_USERNAME", None)
|
||||
self.weaviate_password = os.getenv("WEAVIATE_PASSWORD", None)
|
||||
self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None)
|
||||
self.weaviate_index = os.getenv("WEAVIATE_INDEX", 'auto-gpt')
|
||||
|
||||
# User agent headers to use when browsing web
|
||||
# Some websites might just completely deny request with an error code if no user agent was found.
|
||||
self.user_agent_header = {"User-Agent":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"}
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
from providers.pinecone import PineconeMemory
|
||||
from providers.weaviate import WeaviateMemory
|
||||
|
||||
class MemoryFactory:
    """Builds the memory backend selected by name."""

    @staticmethod
    def get_memory(mem_type):
        """Return the memory backend instance for ``mem_type``.

        Args:
            mem_type: Provider name, either ``'pinecone'`` or ``'weaviate'``.

        Returns:
            A ``PineconeMemory`` or ``WeaviateMemory`` instance.

        Raises:
            ValueError: If ``mem_type`` is not a known provider name.
        """
        if mem_type == 'pinecone':
            return PineconeMemory()

        if mem_type == 'weaviate':
            return WeaviateMemory()

        # Fail here with a clear message instead of implicitly returning
        # None, which would only surface later as an AttributeError on the
        # first use of the returned "memory".
        raise ValueError(
            f"Unknown memory provider {mem_type!r}; "
            "expected 'pinecone' or 'weaviate'"
        )
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import random
|
||||
import commands as cmd
|
||||
from memory import PineconeMemory
|
||||
from factory import MemoryFactory
|
||||
import data
|
||||
import chat
|
||||
from colorama import Fore, Style
|
||||
|
@ -283,7 +283,7 @@ user_input = "Determine which next command to use, and respond using the format
|
|||
|
||||
# Initialize memory and make sure it is empty.
|
||||
# this is particularly important for indexing and referencing pinecone memory
|
||||
memory = PineconeMemory()
|
||||
memory = MemoryFactory.get_memory(cfg.memory_provider)
|
||||
memory.clear()
|
||||
|
||||
print('Using memory of type: ' + memory.__class__.__name__)
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
from config import Singleton
|
||||
import openai
|
||||
|
||||
def get_ada_embedding(text):
    """Return the ``text-embedding-ada-002`` embedding of ``text``.

    Newlines in ``text`` are replaced with spaces before the request is
    made through ``openai.Embedding.create``; the ``embedding`` field of the
    first item in the response is returned.
    """
    single_line = text.replace("\n", " ")
    response = openai.Embedding.create(
        input=[single_line],
        model="text-embedding-ada-002",
    )
    return response["data"][0]["embedding"]
|
||||
|
||||
|
||||
def get_text_from_embedding(embedding):
    """Return the ``text`` field of the first item retrieved for ``embedding``."""
    # NOTE(review): presumably meant to map an embedding back to its source
    # text; confirm that openai.Embedding.retrieve actually supports this.
    retrieved = openai.Embedding.retrieve(embedding, model="text-embedding-ada-002")
    first_item = retrieved["data"][0]
    return first_item["text"]
|
||||
|
||||
class Memory(metaclass=Singleton):
    """Interface shared by the memory providers.

    Every method here raises ``NotImplementedError``; a provider subclass
    is expected to override all of them.
    """

    def add(self, data):
        """Store ``data`` in memory."""
        raise NotImplementedError

    def get(self, data):
        """Look up memory content for ``data``."""
        raise NotImplementedError

    def clear(self):
        """Remove everything held in memory."""
        raise NotImplementedError

    def get_relevant(self, data, num_relevant=5):
        """Return up to ``num_relevant`` stored entries relevant to ``data``."""
        raise NotImplementedError

    def get_stats(self):
        """Return statistics about the memory store."""
        raise NotImplementedError
|
|
@ -1,20 +1,10 @@
|
|||
from config import Config, Singleton
|
||||
from config import Config
|
||||
from providers.memory import Memory, get_ada_embedding
|
||||
import pinecone
|
||||
import openai
|
||||
|
||||
cfg = Config()
|
||||
|
||||
|
||||
def get_ada_embedding(text):
|
||||
text = text.replace("\n", " ")
|
||||
return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]
|
||||
|
||||
|
||||
def get_text_from_embedding(embedding):
|
||||
return openai.Embedding.retrieve(embedding, model="text-embedding-ada-002")["data"][0]["text"]
|
||||
|
||||
|
||||
class PineconeMemory(metaclass=Singleton):
|
||||
class PineconeMemory(Memory):
|
||||
def __init__(self):
|
||||
pinecone_api_key = cfg.pinecone_api_key
|
||||
pinecone_region = cfg.pinecone_region
|
||||
|
@ -58,4 +48,4 @@ class PineconeMemory(metaclass=Singleton):
|
|||
return [str(item['metadata']["raw_text"]) for item in sorted_results]
|
||||
|
||||
def get_stats(self):
|
||||
return self.index.describe_index_stats()
|
||||
return self.index.describe_index_stats()
|
|
@ -0,0 +1,100 @@
|
|||
from config import Config
|
||||
from providers.memory import Memory, get_ada_embedding
|
||||
from weaviate import Client
|
||||
import weaviate
|
||||
import uuid
|
||||
|
||||
from weaviate.util import generate_uuid5
|
||||
|
||||
cfg = Config()
|
||||
|
||||
# Definition of the Weaviate class used as the memory index: it carries a
# single text property holding the original text of each stored embedding.
# NOTE(review): Weaviate class names are expected to start with an uppercase
# letter -- confirm that cfg.weaviate_index always satisfies this.
_RAW_TEXT_PROPERTY = {
    "name": "raw_text",
    "dataType": ["text"],
    "description": "original text for the embedding",
}

SCHEMA = {
    "class": cfg.weaviate_index,
    "properties": [_RAW_TEXT_PROPERTY],
}
|
||||
|
||||
class WeaviateMemory(Memory):
    """Memory provider that stores text and its embedding in Weaviate.

    Connection settings and the class (index) name are read from the
    module-level ``cfg``; the class layout comes from the module-level
    ``SCHEMA``.
    """

    def __init__(self):
        """Connect to the configured Weaviate instance and ensure the schema exists."""
        auth_credentials = self._build_auth_credentials()

        url = f'{cfg.weaviate_host}:{cfg.weaviate_port}'

        self.client = Client(url, auth_client_secret=auth_credentials)

        self._create_schema()

    def _create_schema(self):
        """Create the memory class in Weaviate unless it is already present."""
        if not self.client.schema.contains(SCHEMA):
            self.client.schema.create_class(SCHEMA)

    @staticmethod
    def _build_auth_credentials():
        """Return password credentials, or ``None`` when none are configured.

        Credentials are built only when both ``cfg.weaviate_username`` and
        ``cfg.weaviate_password`` are set.
        """
        if cfg.weaviate_username and cfg.weaviate_password:
            # Use the imported `weaviate` package; the previous name
            # `weaviate_auth` was never defined and raised NameError.
            return weaviate.auth.AuthClientPassword(
                cfg.weaviate_username, cfg.weaviate_password
            )
        return None

    def add(self, data):
        """Embed ``data`` and store it in Weaviate.

        The object id is derived from the text and the index name via
        ``generate_uuid5``.

        Returns:
            A message containing the uuid and the inserted data.
        """
        vector = get_ada_embedding(data)

        doc_uuid = generate_uuid5(data, cfg.weaviate_index)

        # Only schema-declared properties are stored; the target class is
        # passed through `class_name` below, so no 'class' property is needed.
        data_object = {'raw_text': data}

        with self.client.batch as batch:
            batch.add_data_object(
                uuid=doc_uuid,
                data_object=data_object,
                class_name=cfg.weaviate_index,
                vector=vector,
            )

            batch.flush()

        return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"

    def get(self, data):
        """Return the single most relevant stored text for ``data`` as a list."""
        return self.get_relevant(data, 1)

    def clear(self):
        """Remove all stored memories and re-create the empty class.

        Returns:
            The string ``'Obliterated'``.
        """
        # Drop only this application's class. `schema.delete_all()` would
        # also wipe every other class stored in the same Weaviate instance.
        self.client.schema.delete_class(cfg.weaviate_index)

        # weaviate does not yet have a neat way to just remove the items in
        # an index without removing the class itself, therefore we need to
        # re-create it after the delete
        self._create_schema()

        return 'Obliterated'

    def get_relevant(self, data, num_relevant=5):
        """Return the stored texts closest to ``data``.

        Runs a near-vector query (certainty 0.7) limited to ``num_relevant``
        results.

        Returns:
            A list of ``raw_text`` strings; an empty list when nothing
            matches or when the query fails.
        """
        query_embedding = get_ada_embedding(data)
        try:
            results = (
                self.client.query.get(cfg.weaviate_index, ['raw_text'])
                .with_near_vector({'vector': query_embedding, 'certainty': 0.7})
                .with_limit(num_relevant)
                .do()
            )

            matches = results['data']['Get'][cfg.weaviate_index]
            return [str(item['raw_text']) for item in matches]

        except Exception as err:
            # Deliberate best-effort: a failed lookup is reported and treated
            # as "no relevant memories" instead of aborting the caller.
            print(f'Unexpected error {err=}, {type(err)=}')
            return []

    def get_stats(self):
        """Return the aggregate ``meta`` information (object count) of the index.

        Returns:
            The ``meta`` dict of the aggregate query, or an empty dict when
            the query returns no data for the class.
        """
        # The client object exposes no `index_stats` attribute; use an
        # aggregate query with a meta count instead.
        result = (
            self.client.query.aggregate(cfg.weaviate_index)
            .with_meta_count()
            .do()
        )
        class_data = result['data']['Aggregate'][cfg.weaviate_index]
        return class_data[0]['meta'] if class_data else {}
|
Loading…
Reference in New Issue