2023-04-07 05:08:25 +00:00
|
|
|
"""Base class for memory providers."""
|
2023-04-07 03:25:17 +00:00
|
|
|
import abc
|
2023-04-14 19:42:28 +00:00
|
|
|
|
2023-04-07 03:25:17 +00:00
|
|
|
import openai
|
|
|
|
|
2023-04-14 19:42:28 +00:00
|
|
|
from autogpt.config import AbstractSingleton, Config
|
|
|
|
|
2023-04-13 17:01:12 +00:00
|
|
|
# Module-level configuration object, created once at import time.
# get_ada_embedding reads `use_azure` and the Azure deployment-id mapping from it.
cfg = Config()
|
2023-04-07 03:25:17 +00:00
|
|
|
|
2023-04-13 17:57:14 +00:00
|
|
|
|
2023-04-07 03:25:17 +00:00
|
|
|
def get_ada_embedding(text):
    """Return the ``text-embedding-ada-002`` embedding for *text*.

    Newlines in *text* are replaced with single spaces before the request is
    sent. When ``cfg.use_azure`` is true, the request is addressed to the Azure
    deployment that ``cfg`` maps to the ada model (sent as ``engine``).
    Otherwise the model is requested by name (sent as ``model``).

    Args:
        text: The string to embed.

    Returns:
        The ``"embedding"`` entry of the first element of the response's
        ``"data"`` list.
    """
    model_name = "text-embedding-ada-002"
    flattened = text.replace("\n", " ")

    # Azure routes requests by deployment id; the public API routes by model name.
    if cfg.use_azure:
        target = {"engine": cfg.get_azure_deployment_id_for_model(model_name)}
    else:
        target = {"model": model_name}

    response = openai.Embedding.create(input=[flattened], **target)
    return response["data"][0]["embedding"]
|
2023-04-07 03:25:17 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MemoryProviderSingleton(AbstractSingleton):
    """Abstract interface that every memory provider implements.

    Concrete providers must override all five abstract methods below. The base
    implementations have empty bodies and return ``None``.

    The parent class comes from ``autogpt.config.AbstractSingleton``.
    NOTE(review): it presumably supplies the ABC metaclass and per-class
    singleton behaviour; confirm in autogpt.config.
    """

    @abc.abstractmethod
    def add(self, data):
        """Presumably stores ``data``; exact semantics are defined by each subclass."""

    @abc.abstractmethod
    def get(self, data):
        """Presumably retrieves memory for ``data``; the return value is subclass-defined."""

    @abc.abstractmethod
    def clear(self):
        """Presumably discards the provider's stored memory; subclass-defined."""

    @abc.abstractmethod
    def get_relevant(self, data, num_relevant=5):
        """Presumably returns up to ``num_relevant`` (default 5) items relevant to ``data``."""

    @abc.abstractmethod
    def get_stats(self):
        """Presumably reports statistics about stored memory; the format is subclass-defined."""
|