diff --git a/scripts/memory/base.py b/scripts/memory/base.py index e3924d7e0..a0b3f25fc 100644 --- a/scripts/memory/base.py +++ b/scripts/memory/base.py @@ -3,6 +3,7 @@ import abc from config import AbstractSingleton, Config import openai +# try to import sentence transformers, if it fails, default to ada try: from sentence_transformers import SentenceTransformer except ImportError: @@ -18,6 +19,7 @@ cfg = Config() def get_embedding(text): text = text.replace("\n", " ") + # use the embedder specified in the config if cfg.memory_embeder == "sbert": embedding = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device="cpu").encode(text, show_progress_bar=False) else: diff --git a/scripts/memory/local.py b/scripts/memory/local.py index 40a08f66b..728723cbd 100644 --- a/scripts/memory/local.py +++ b/scripts/memory/local.py @@ -6,8 +6,10 @@ import os from memory.base import MemoryProviderSingleton, get_embedding from config import Config +# TODO: get the embeddings dimension without importing config cfg = Config() +# set the embedding dimension based on the embedder EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768 SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS diff --git a/scripts/memory/pinecone.py b/scripts/memory/pinecone.py index b3aab33ab..e8a713162 100644 --- a/scripts/memory/pinecone.py +++ b/scripts/memory/pinecone.py @@ -10,6 +10,7 @@ class PineconeMemory(MemoryProviderSingleton): pinecone_api_key = cfg.pinecone_api_key pinecone_region = cfg.pinecone_region pinecone.init(api_key=pinecone_api_key, environment=pinecone_region) + # set the embedding dimension based on the embedder dimension = 1536 if cfg.memory_embeder == "ada" else 768 metric = "cosine" pod_type = "p1" diff --git a/scripts/memory/redismem.py b/scripts/memory/redismem.py index 8f3258357..3da528279 100644 --- a/scripts/memory/redismem.py +++ b/scripts/memory/redismem.py @@ -11,8 +11,10 @@ from logger import logger from colorama import Fore, Style from 
config import Config +# TODO: get the embeddings dimension without importing config cfg = Config() +# set the embedding dimension based on the embedder EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768 SCHEMA = [