2023-04-07 05:08:25 +00:00
""" Redis memory provider. """
from typing import Any , List , Optional
import redis
from redis . commands . search . field import VectorField , TextField
from redis . commands . search . query import Query
from redis . commands . search . indexDefinition import IndexDefinition , IndexType
import numpy as np
from memory . base import MemoryProviderSingleton , get_ada_embedding
2023-04-12 20:13:34 +00:00
from logger import logger
2023-04-12 20:49:32 +00:00
from colorama import Fore , Style
2023-04-07 05:08:25 +00:00
# RediSearch schema for one stored memory: the raw text plus its embedding.
SCHEMA = [
    # Raw text of the memory; returned verbatim by similarity searches.
    TextField("data"),
    # Embedding of that text, indexed for KNN queries.
    VectorField(
        "embedding",
        "HNSW",
        {
            "TYPE": "FLOAT32",
            # Must equal the length of the vectors written to the "embedding"
            # hash field (RedisMemory records the same value as `dimension`).
            "DIM": 1536,
            "DISTANCE_METRIC": "COSINE",
        },
    ),
]
class RedisMemory(MemoryProviderSingleton):
    """
    Memory provider that keeps text snippets and their embeddings in Redis.

    Every memory is written as a hash under the key ``<memory_index>:<n>``
    and is searched through the search index named ``<memory_index>``.
    The number of stored memories is persisted in ``<memory_index>-vec_num``
    so that numbering continues where it left off on the next start.
    """

    def __init__(self, cfg):
        """
        Initializes the Redis memory provider.

        Args:
            cfg: The config object. The attributes read here are redis_host,
                redis_port, redis_password, wipe_redis_on_start and
                memory_index.

        Returns: None

        Raises:
            SystemExit: If the Redis server does not answer a ping.
        """
        self.dimension = 1536
        self.redis = redis.Redis(
            host=cfg.redis_host,
            port=cfg.redis_port,
            password=cfg.redis_password,
            db=0,  # Cannot be changed
        )
        self.cfg = cfg

        # Check redis connection
        try:
            self.redis.ping()
        except redis.ConnectionError as e:
            logger.typewriter_log(
                "FAILED TO CONNECT TO REDIS",
                Fore.RED,
                Style.BRIGHT + str(e) + Style.RESET_ALL,
            )
            logger.double_check(
                "Please ensure you have setup and configured Redis properly for use. "
                + f"You can check out {Fore.CYAN + Style.BRIGHT}https://github.com/Torantulino/Auto-GPT#redis-setup{Style.RESET_ALL} to ensure you've set up everything correctly."
            )
            # Same exit status as before, without relying on the `exit`
            # helper that only exists when the `site` module is loaded.
            raise SystemExit(1) from e

        if cfg.wipe_redis_on_start:
            self.redis.flushall()

        self._create_index()

        # Resume numbering from the persisted counter, if there is one.
        existing_vec_num = self.redis.get(f"{cfg.memory_index}-vec_num")
        self.vec_num = (
            int(existing_vec_num.decode("utf-8")) if existing_vec_num else 0
        )

    def _create_index(self) -> None:
        """
        Creates the search index over the ``<memory_index>:`` hash keys.

        Failures are printed instead of raised so that start-up continues
        (for instance when the index is presumably already present).

        Returns: None
        """
        try:
            self.redis.ft(f"{self.cfg.memory_index}").create_index(
                fields=SCHEMA,
                definition=IndexDefinition(
                    prefix=[f"{self.cfg.memory_index}:"],
                    index_type=IndexType.HASH,
                ),
            )
        except Exception as e:
            print("Error creating Redis search index: ", e)

    def add(self, data: str) -> str:
        """
        Adds a data point to the memory.

        Args:
            data: The data to add.

        Returns: Message indicating that the data has been added, or an
            empty string when the data is a command error and was skipped.
        """
        if "Command Error:" in data:
            return ""

        vector = np.array(get_ada_embedding(data)).astype(np.float32).tobytes()
        index = self.vec_num
        next_index = index + 1

        # Write the hash and the updated counter in one pipeline round trip.
        pipe = self.redis.pipeline()
        pipe.hset(
            f"{self.cfg.memory_index}:{index}",
            mapping={"data": data, "embedding": vector},
        )
        pipe.set(f"{self.cfg.memory_index}-vec_num", next_index)
        pipe.execute()

        # Advance the local counter only once the write went through, so a
        # failed execute() does not leave it ahead of what Redis holds.
        self.vec_num = next_index

        return (
            f"Inserting data into memory at index: {index}:\n"
            f"data: {data}"
        )

    def get(self, data: str) -> Optional[List[Any]]:
        """
        Gets the data from the memory that is most relevant to the given data.

        Args:
            data: The data to compare to.

        Returns: A list holding the single most relevant entry, or None if
            the search failed.
        """
        return self.get_relevant(data, 1)

    def clear(self) -> str:
        """
        Clears the redis server.

        Everything on the server is flushed, so the local counter is reset
        and the search index is created again to keep this instance usable.

        Returns: A message indicating that the memory has been cleared.
        """
        self.redis.flushall()
        # The persisted counter was flushed as well; start numbering afresh.
        self.vec_num = 0
        self._create_index()
        return "Obliviated"

    def get_relevant(
        self,
        data: str,
        num_relevant: int = 5
    ) -> Optional[List[Any]]:
        """
        Returns all the data in the memory that is relevant to the given data.

        Args:
            data: The data to compare to.
            num_relevant: The number of relevant data to return.

        Returns: A list of the most relevant data, or None if the search
            raised an error.
        """
        query_embedding = get_ada_embedding(data)
        query_vector = np.array(query_embedding).astype(np.float32).tobytes()

        # KNN search over the "embedding" field, ordered by vector_score.
        base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
        query = (
            Query(base_query)
            .return_fields("data", "vector_score")
            .sort_by("vector_score")
            .dialect(2)
        )

        try:
            results = self.redis.ft(f"{self.cfg.memory_index}").search(
                query, query_params={"vector": query_vector}
            )
        except Exception as e:
            print("Error calling Redis search: ", e)
            return None
        return [result.data for result in results.docs]

    def get_stats(self):
        """
        Returns: The stats of the memory index.
        """
        return self.redis.ft(f"{self.cfg.memory_index}").info()