Merge pull request #1635 from Imccccc/feature/embedding-with-retry

Embedding Improvement
pull/1702/head
BillSchumacher 2023-04-15 15:04:31 -05:00 committed by GitHub
commit 898b7eed8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 9 deletions

View File

@ -113,3 +113,35 @@ def create_chat_completion(
raise RuntimeError(f"Failed to get response after {num_retries} retries")
return response.choices[0].message["content"]
def create_embedding_with_ada(text) -> list:
    """Create an embedding with text-embedding-ada-002 using the OpenAI SDK.

    Retries up to ``num_retries`` times with exponential backoff on rate
    limits and 502 Bad Gateway responses.

    Args:
        text: The text to embed.

    Returns:
        list: The embedding vector for *text*.

    Raises:
        RateLimitError: If rate-limited on the final attempt.
        APIError: On a non-502 API error, or a 502 on the final attempt.
    """
    num_retries = 10
    for attempt in range(num_retries):
        backoff = 2 ** (attempt + 2)
        try:
            if CFG.use_azure:
                return openai.Embedding.create(
                    input=[text],
                    engine=CFG.get_azure_deployment_id_for_model(
                        "text-embedding-ada-002"
                    ),
                )["data"][0]["embedding"]
            else:
                return openai.Embedding.create(
                    input=[text], model="text-embedding-ada-002"
                )["data"][0]["embedding"]
        except RateLimitError:
            # Re-raise inside the handler so the original exception
            # propagates; a bare `raise` after the except block would fail
            # with "No active exception to re-raise".
            if attempt == num_retries - 1:
                raise
        except APIError as e:
            # Only 502 (Bad Gateway) is retryable; anything else — or a 502
            # on the last attempt — propagates to the caller.
            if e.http_status != 502 or attempt == num_retries - 1:
                raise
        if CFG.debug_mode:
            print(
                Fore.RED + "Error: ",
                f"API Bad gateway. Waiting {backoff} seconds..." + Fore.RESET,
            )
        time.sleep(backoff)

View File

@ -5,7 +5,8 @@ from typing import Any, List, Optional, Tuple
import numpy as np
import orjson
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
from autogpt.memory.base import MemoryProviderSingleton
from autogpt.llm_utils import create_embedding_with_ada
EMBED_DIM = 1536
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS
@ -70,7 +71,7 @@ class LocalCache(MemoryProviderSingleton):
return ""
self.data.texts.append(text)
embedding = get_ada_embedding(text)
embedding = create_embedding_with_ada(text)
vector = np.array(embedding).astype(np.float32)
vector = vector[np.newaxis, :]
@ -118,7 +119,7 @@ class LocalCache(MemoryProviderSingleton):
Returns: List[str]
"""
embedding = get_ada_embedding(text)
embedding = create_embedding_with_ada(text)
scores = np.dot(self.data.embeddings, embedding)

View File

@ -2,7 +2,8 @@ import pinecone
from colorama import Fore, Style
from autogpt.logs import logger
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
from autogpt.memory.base import MemoryProviderSingleton
from autogpt.llm_utils import create_embedding_with_ada
class PineconeMemory(MemoryProviderSingleton):
@ -43,7 +44,7 @@ class PineconeMemory(MemoryProviderSingleton):
self.index = pinecone.Index(table_name)
def add(self, data):
vector = get_ada_embedding(data)
vector = create_embedding_with_ada(data)
# no metadata here. We may wish to change that long term.
self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})])
_text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}"
@ -63,7 +64,7 @@ class PineconeMemory(MemoryProviderSingleton):
:param data: The data to compare to.
:param num_relevant: The number of relevant data to return. Defaults to 5
"""
query_embedding = get_ada_embedding(data)
query_embedding = create_embedding_with_ada(data)
results = self.index.query(
query_embedding, top_k=num_relevant, include_metadata=True
)

View File

@ -9,7 +9,8 @@ from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from autogpt.logs import logger
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
from autogpt.memory.base import MemoryProviderSingleton
from autogpt.llm_utils import create_embedding_with_ada
SCHEMA = [
TextField("data"),
@ -85,7 +86,7 @@ class RedisMemory(MemoryProviderSingleton):
"""
if "Command Error:" in data:
return ""
vector = get_ada_embedding(data)
vector = create_embedding_with_ada(data)
vector = np.array(vector).astype(np.float32).tobytes()
data_dict = {b"data": data, "embedding": vector}
pipe = self.redis.pipeline()
@ -127,7 +128,7 @@ class RedisMemory(MemoryProviderSingleton):
Returns: A list of the most relevant data.
"""
query_embedding = get_ada_embedding(data)
query_embedding = create_embedding_with_ada(data)
base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
query = (
Query(base_query)