Merge pull request #1635 from Imccccc/feature/embedding-with-retry
Embedding Improvement (pull/1702/head)
commit
898b7eed8a
|
@ -113,3 +113,35 @@ def create_chat_completion(
|
|||
raise RuntimeError(f"Failed to get response after {num_retries} retries")
|
||||
|
||||
return response.choices[0].message["content"]
|
||||
|
||||
|
||||
def create_embedding_with_ada(text) -> list:
    """Create an embedding with text-embedding-ada-002 using the OpenAI SDK.

    The request is retried with exponential backoff when the API reports a
    rate limit or answers with HTTP 502.

    Args:
        text: The text to embed; it is sent as a single-element input list.

    Returns:
        The "embedding" entry of the first item in the response's "data" list.

    Raises:
        RateLimitError: If the final attempt is still rate limited.
        APIError: Immediately for any HTTP status other than 502, or if the
            final attempt still fails with 502.
    """
    num_retries = 10
    model = "text-embedding-ada-002"
    for attempt in range(num_retries):
        # Exponential backoff: 4, 8, 16, ... seconds.
        backoff = 2 ** (attempt + 2)
        is_last_attempt = attempt == num_retries - 1
        try:
            if CFG.use_azure:
                # Azure addresses models through a deployment id ("engine").
                response = openai.Embedding.create(
                    input=[text],
                    engine=CFG.get_azure_deployment_id_for_model(model),
                )
            else:
                response = openai.Embedding.create(input=[text], model=model)
            return response["data"][0]["embedding"]
        except RateLimitError:
            # Re-raise once retries are exhausted instead of silently falling
            # out of the loop and returning None to the caller.
            if is_last_attempt:
                raise
        except APIError as e:
            # Only 502 is treated as transient; everything else propagates.
            if e.http_status != 502 or is_last_attempt:
                raise
        # Reached only after a retryable failure with attempts remaining.
        # NOTE: the message below is also printed for rate-limit retries.
        if CFG.debug_mode:
            print(
                Fore.RED + "Error: ",
                f"API Bad gateway. Waiting {backoff} seconds..." + Fore.RESET,
            )
        time.sleep(backoff)
|
||||
|
||||
|
|
|
@ -5,7 +5,8 @@ from typing import Any, List, Optional, Tuple
|
|||
import numpy as np
|
||||
import orjson
|
||||
|
||||
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
|
||||
from autogpt.memory.base import MemoryProviderSingleton
|
||||
from autogpt.llm_utils import create_embedding_with_ada
|
||||
|
||||
EMBED_DIM = 1536
|
||||
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS
|
||||
|
@ -70,7 +71,7 @@ class LocalCache(MemoryProviderSingleton):
|
|||
return ""
|
||||
self.data.texts.append(text)
|
||||
|
||||
embedding = get_ada_embedding(text)
|
||||
embedding = create_embedding_with_ada(text)
|
||||
|
||||
vector = np.array(embedding).astype(np.float32)
|
||||
vector = vector[np.newaxis, :]
|
||||
|
@ -118,7 +119,7 @@ class LocalCache(MemoryProviderSingleton):
|
|||
|
||||
Returns: List[str]
|
||||
"""
|
||||
embedding = get_ada_embedding(text)
|
||||
embedding = create_embedding_with_ada(text)
|
||||
|
||||
scores = np.dot(self.data.embeddings, embedding)
|
||||
|
||||
|
|
|
@ -2,7 +2,8 @@ import pinecone
|
|||
from colorama import Fore, Style
|
||||
|
||||
from autogpt.logs import logger
|
||||
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
|
||||
from autogpt.memory.base import MemoryProviderSingleton
|
||||
from autogpt.llm_utils import create_embedding_with_ada
|
||||
|
||||
|
||||
class PineconeMemory(MemoryProviderSingleton):
|
||||
|
@ -43,7 +44,7 @@ class PineconeMemory(MemoryProviderSingleton):
|
|||
self.index = pinecone.Index(table_name)
|
||||
|
||||
def add(self, data):
|
||||
vector = get_ada_embedding(data)
|
||||
vector = create_embedding_with_ada(data)
|
||||
# no metadata here. We may wish to change that long term.
|
||||
self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})])
|
||||
_text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}"
|
||||
|
@ -63,7 +64,7 @@ class PineconeMemory(MemoryProviderSingleton):
|
|||
:param data: The data to compare to.
|
||||
:param num_relevant: The number of relevant data to return. Defaults to 5
|
||||
"""
|
||||
query_embedding = get_ada_embedding(data)
|
||||
query_embedding = create_embedding_with_ada(data)
|
||||
results = self.index.query(
|
||||
query_embedding, top_k=num_relevant, include_metadata=True
|
||||
)
|
||||
|
|
|
@ -9,7 +9,8 @@ from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
|||
from redis.commands.search.query import Query
|
||||
|
||||
from autogpt.logs import logger
|
||||
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
|
||||
from autogpt.memory.base import MemoryProviderSingleton
|
||||
from autogpt.llm_utils import create_embedding_with_ada
|
||||
|
||||
SCHEMA = [
|
||||
TextField("data"),
|
||||
|
@ -85,7 +86,7 @@ class RedisMemory(MemoryProviderSingleton):
|
|||
"""
|
||||
if "Command Error:" in data:
|
||||
return ""
|
||||
vector = get_ada_embedding(data)
|
||||
vector = create_embedding_with_ada(data)
|
||||
vector = np.array(vector).astype(np.float32).tobytes()
|
||||
data_dict = {b"data": data, "embedding": vector}
|
||||
pipe = self.redis.pipeline()
|
||||
|
@ -127,7 +128,7 @@ class RedisMemory(MemoryProviderSingleton):
|
|||
|
||||
Returns: A list of the most relevant data.
|
||||
"""
|
||||
query_embedding = get_ada_embedding(data)
|
||||
query_embedding = create_embedding_with_ada(data)
|
||||
base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
|
||||
query = (
|
||||
Query(base_query)
|
||||
|
|
Loading…
Reference in New Issue