116 lines
3.7 KiB
Python
116 lines
3.7 KiB
Python
""" Milvus memory storage provider."""
|
|
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections
|
|
|
|
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
|
|
|
|
|
|
class MilvusMemory(MemoryProviderSingleton):
    """Milvus memory storage provider.

    Stores raw text together with its embedding in a Milvus collection and
    retrieves the most similar texts with a vector search.
    """

    # Similarity metric used both when building the index and when searching.
    # Kept in one place so the two can never disagree.
    _METRIC_TYPE = "IP"
    # Query-time HNSW candidate list size used when the caller asks for fewer
    # results than this; Milvus requires ef >= the requested result count.
    _DEFAULT_SEARCH_EF = 64

    def __init__(self, cfg) -> None:
        """Construct a milvus memory storage connection.

        Connects to the server, opens (or creates) the configured collection,
        makes sure the vector index exists and loads the collection.

        Args:
            cfg (Config): Auto-GPT global config. Reads ``cfg.milvus_addr``
                and ``cfg.milvus_collection``.
        """
        # connect to milvus server.
        connections.connect(address=cfg.milvus_addr)
        fields = [
            FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
            # NOTE(review): dim=1536 presumably matches the vectors returned by
            # get_ada_embedding -- confirm if the embedding model changes.
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=1536),
            FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
        ]

        # create collection if not exist and load it.
        self.milvus_collection = cfg.milvus_collection
        self.schema = CollectionSchema(fields, "auto-gpt memory storage")
        self.collection = Collection(self.milvus_collection, self.schema)
        # create index if not exist.
        if not self.collection.has_index():
            # Release before building the index, as the original code did.
            self.collection.release()
            self._create_index()
        self.collection.load()

    def _create_index(self) -> None:
        """Build the HNSW index on the ``embeddings`` field of the collection."""
        self.collection.create_index(
            "embeddings",
            {
                "metric_type": self._METRIC_TYPE,
                "index_type": "HNSW",
                "params": {"M": 8, "efConstruction": 64},
            },
            index_name="embeddings",
        )

    def add(self, data) -> str:
        """Add an embedding of data into memory.

        Args:
            data (str): The raw text to construct embedding index.

        Returns:
            str: log message containing the primary key of the inserted row.
        """
        embedding = get_ada_embedding(data)
        # Column-wise insert: one embedding column and one raw_text column;
        # the "pk" column is omitted because the schema sets auto_id=True.
        result = self.collection.insert([[embedding], [data]])
        _text = (
            "Inserting data into memory at primary key: "
            f"{result.primary_keys[0]}:\n data: {data}"
        )
        return _text

    def get(self, data):
        """Return the most relevant data in memory.

        Args:
            data: The data to compare to.

        Returns:
            list: The result of ``get_relevant(data, 1)``.
        """
        return self.get_relevant(data, 1)

    def clear(self) -> str:
        """Drop the collection and recreate it empty.

        The collection is dropped, created again from the stored schema,
        re-indexed and loaded.

        Returns:
            str: log.
        """
        self.collection.drop()
        self.collection = Collection(self.milvus_collection, self.schema)
        self._create_index()
        self.collection.load()
        return "Obliviated"

    def get_relevant(self, data: str, num_relevant: int = 5):
        """Return the top-k relevant data in memory.

        Args:
            data: The data to compare to.
            num_relevant (int, optional): The max number of relevant data.
                Defaults to 5.

        Returns:
            list: The ``raw_text`` values of the top-k relevant entries.
        """
        # search the embedding and return the most relevant text.
        embedding = get_ada_embedding(data)
        search_params = {
            # The documented key is "metric_type"; the previous "metrics_type"
            # spelling is not a documented search parameter.
            "metric_type": self._METRIC_TYPE,
            # The index is HNSW, whose query-time knob is "ef" (not the IVF
            # "nprobe"); ef has to be at least the number of results requested.
            "params": {"ef": max(self._DEFAULT_SEARCH_EF, num_relevant)},
        }
        result = self.collection.search(
            [embedding],
            "embeddings",
            search_params,
            num_relevant,
            output_fields=["raw_text"],
        )
        # A single query vector was sent, so only the first hit list is read.
        return [item.entity.value_of_field("raw_text") for item in result[0]]

    def get_stats(self) -> str:
        """Return the stats of the milvus cache.

        Returns:
            str: The number of entities reported by the collection.
        """
        return f"Entities num: {self.collection.num_entities}"
|