Merge pull request #1635 from Imccccc/feature/embedding-with-retry
Embedding Improvement (branch: pull/1702/head)
commit
898b7eed8a
|
@ -113,3 +113,35 @@ def create_chat_completion(
|
||||||
raise RuntimeError(f"Failed to get response after {num_retries} retries")
|
raise RuntimeError(f"Failed to get response after {num_retries} retries")
|
||||||
|
|
||||||
return response.choices[0].message["content"]
|
return response.choices[0].message["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_with_ada(text) -> list:
    """Create an embedding with text-embedding-ada-002 using the OpenAI SDK.

    Retries up to ``num_retries`` times with exponential backoff on rate
    limits and transient 502 Bad Gateway errors.

    Args:
        text: The text to embed.

    Returns:
        list: The embedding vector for ``text``.

    Raises:
        APIError: Re-raised for any non-502 error, or for a 502 on the
            final attempt.
        RuntimeError: If all retries are exhausted without a response
            (matches create_chat_completion's exhaustion behavior).
    """
    num_retries = 10
    for attempt in range(num_retries):
        # Exponential backoff: 4, 8, 16, ... seconds.
        backoff = 2 ** (attempt + 2)
        try:
            if CFG.use_azure:
                return openai.Embedding.create(
                    input=[text],
                    engine=CFG.get_azure_deployment_id_for_model(
                        "text-embedding-ada-002"
                    ),
                )["data"][0]["embedding"]
            else:
                return openai.Embedding.create(
                    input=[text], model="text-embedding-ada-002"
                )["data"][0]["embedding"]
        except RateLimitError:
            # Rate-limited: fall through to the backoff sleep and retry.
            pass
        except APIError as e:
            # Only a 502 is retryable; anything else — or a 502 on the
            # last attempt — propagates to the caller.
            if e.http_status != 502 or attempt == num_retries - 1:
                raise
        # BUGFIX: previously the last-attempt check lived inside the
        # APIError handler, so a RateLimitError on the final attempt fell
        # through and the function silently returned None.
        if CFG.debug_mode:
            print(
                Fore.RED + "Error: ",
                f"API Bad gateway. Waiting {backoff} seconds..." + Fore.RESET,
            )
        time.sleep(backoff)
    raise RuntimeError(f"Failed to get response after {num_retries} retries")
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,8 @@ from typing import Any, List, Optional, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
|
from autogpt.memory.base import MemoryProviderSingleton
|
||||||
|
from autogpt.llm_utils import create_embedding_with_ada
|
||||||
|
|
||||||
EMBED_DIM = 1536
|
EMBED_DIM = 1536
|
||||||
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS
|
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS
|
||||||
|
@ -70,7 +71,7 @@ class LocalCache(MemoryProviderSingleton):
|
||||||
return ""
|
return ""
|
||||||
self.data.texts.append(text)
|
self.data.texts.append(text)
|
||||||
|
|
||||||
embedding = get_ada_embedding(text)
|
embedding = create_embedding_with_ada(text)
|
||||||
|
|
||||||
vector = np.array(embedding).astype(np.float32)
|
vector = np.array(embedding).astype(np.float32)
|
||||||
vector = vector[np.newaxis, :]
|
vector = vector[np.newaxis, :]
|
||||||
|
@ -118,7 +119,7 @@ class LocalCache(MemoryProviderSingleton):
|
||||||
|
|
||||||
Returns: List[str]
|
Returns: List[str]
|
||||||
"""
|
"""
|
||||||
embedding = get_ada_embedding(text)
|
embedding = create_embedding_with_ada(text)
|
||||||
|
|
||||||
scores = np.dot(self.data.embeddings, embedding)
|
scores = np.dot(self.data.embeddings, embedding)
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,8 @@ import pinecone
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
|
|
||||||
from autogpt.logs import logger
|
from autogpt.logs import logger
|
||||||
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
|
from autogpt.memory.base import MemoryProviderSingleton
|
||||||
|
from autogpt.llm_utils import create_embedding_with_ada
|
||||||
|
|
||||||
|
|
||||||
class PineconeMemory(MemoryProviderSingleton):
|
class PineconeMemory(MemoryProviderSingleton):
|
||||||
|
@ -43,7 +44,7 @@ class PineconeMemory(MemoryProviderSingleton):
|
||||||
self.index = pinecone.Index(table_name)
|
self.index = pinecone.Index(table_name)
|
||||||
|
|
||||||
def add(self, data):
|
def add(self, data):
|
||||||
vector = get_ada_embedding(data)
|
vector = create_embedding_with_ada(data)
|
||||||
# no metadata here. We may wish to change that long term.
|
# no metadata here. We may wish to change that long term.
|
||||||
self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})])
|
self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})])
|
||||||
_text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}"
|
_text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}"
|
||||||
|
@ -63,7 +64,7 @@ class PineconeMemory(MemoryProviderSingleton):
|
||||||
:param data: The data to compare to.
|
:param data: The data to compare to.
|
||||||
:param num_relevant: The number of relevant data to return. Defaults to 5
|
:param num_relevant: The number of relevant data to return. Defaults to 5
|
||||||
"""
|
"""
|
||||||
query_embedding = get_ada_embedding(data)
|
query_embedding = create_embedding_with_ada(data)
|
||||||
results = self.index.query(
|
results = self.index.query(
|
||||||
query_embedding, top_k=num_relevant, include_metadata=True
|
query_embedding, top_k=num_relevant, include_metadata=True
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,7 +9,8 @@ from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
||||||
from redis.commands.search.query import Query
|
from redis.commands.search.query import Query
|
||||||
|
|
||||||
from autogpt.logs import logger
|
from autogpt.logs import logger
|
||||||
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
|
from autogpt.memory.base import MemoryProviderSingleton
|
||||||
|
from autogpt.llm_utils import create_embedding_with_ada
|
||||||
|
|
||||||
SCHEMA = [
|
SCHEMA = [
|
||||||
TextField("data"),
|
TextField("data"),
|
||||||
|
@ -85,7 +86,7 @@ class RedisMemory(MemoryProviderSingleton):
|
||||||
"""
|
"""
|
||||||
if "Command Error:" in data:
|
if "Command Error:" in data:
|
||||||
return ""
|
return ""
|
||||||
vector = get_ada_embedding(data)
|
vector = create_embedding_with_ada(data)
|
||||||
vector = np.array(vector).astype(np.float32).tobytes()
|
vector = np.array(vector).astype(np.float32).tobytes()
|
||||||
data_dict = {b"data": data, "embedding": vector}
|
data_dict = {b"data": data, "embedding": vector}
|
||||||
pipe = self.redis.pipeline()
|
pipe = self.redis.pipeline()
|
||||||
|
@ -127,7 +128,7 @@ class RedisMemory(MemoryProviderSingleton):
|
||||||
|
|
||||||
Returns: A list of the most relevant data.
|
Returns: A list of the most relevant data.
|
||||||
"""
|
"""
|
||||||
query_embedding = get_ada_embedding(data)
|
query_embedding = create_embedding_with_ada(data)
|
||||||
base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
|
base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
|
||||||
query = (
|
query = (
|
||||||
Query(base_query)
|
Query(base_query)
|
||||||
|
|
Loading…
Reference in New Issue