AutoGPT/scripts/memory/milvus.py

91 lines
3.0 KiB
Python

from pymilvus import (
connections,
FieldSchema,
CollectionSchema,
DataType,
Collection,
)
from memory.base import MemoryProviderSingleton, get_ada_embedding
class MilvusMemory(MemoryProviderSingleton):
    """Memory provider that stores text and its embedding in a Milvus collection.

    Each entry consists of an auto-generated primary key, the embedding vector
    returned by ``get_ada_embedding`` and the raw text it was computed from.
    """

    # Shared by index creation and search so the two can never disagree.
    _VECTOR_FIELD = "embeddings"
    _TEXT_FIELD = "raw_text"
    _METRIC_TYPE = "IP"
    # Must equal the length of the vectors produced by get_ada_embedding.
    _EMBEDDING_DIM = 1536

    def __init__(self, cfg):
        """Construct a milvus memory storage connection.

        Args:
            cfg (Config): Auto-GPT global config. ``cfg.milvus_addr`` and
                ``cfg.milvus_collection`` are read.
        """
        # connect to milvus server.
        connections.connect(address=cfg.milvus_addr)
        fields = [
            FieldSchema(name="pk", dtype=DataType.INT64,
                        is_primary=True, auto_id=True),
            FieldSchema(name=self._VECTOR_FIELD,
                        dtype=DataType.FLOAT_VECTOR, dim=self._EMBEDDING_DIM),
            FieldSchema(name=self._TEXT_FIELD, dtype=DataType.VARCHAR,
                        max_length=65535),
        ]
        # Kept on the instance so clear() can rebuild the collection after
        # dropping it.
        self._collection_name = cfg.milvus_collection
        self._schema = CollectionSchema(fields, "auto-gpt memory storage")
        self._open_collection()

    def _open_collection(self):
        """Create (or attach to) the collection, ensure its index, and load it."""
        # create collection if not exist.
        self.collection = Collection(self._collection_name, self._schema)
        # create index if not exist.
        if not self.collection.has_index(index_name=self._VECTOR_FIELD):
            # The collection is released before the index is built.
            self.collection.release()
            self.collection.create_index(self._VECTOR_FIELD, {
                "index_type": "IVF_FLAT",
                "metric_type": self._METRIC_TYPE,
                "params": {"nlist": 128},
            }, index_name=self._VECTOR_FIELD)
        self.collection.load()

    def add(self, data: str) -> str:
        """Add an embedding of data into memory.

        Args:
            data (str): The raw text to construct embedding index.

        Returns:
            str: log message containing the primary key of the new entry.
        """
        embedding = get_ada_embedding(data)
        # Column order follows the schema without the auto-generated "pk".
        result = self.collection.insert([[embedding], [data]])
        _text = f"Inserting data into memory at primary key: {result.primary_keys[0]}:\n data: {data}"
        return _text

    def get(self, data):
        """Return the most relevant data in memory.

        Args:
            data: The data to compare to.

        Returns:
            list: At most one raw text entry, as returned by ``get_relevant``.
        """
        return self.get_relevant(data, 1)

    def clear(self) -> str:
        """Drop the collection and recreate it empty.

        The collection is rebuilt (with its index) and loaded again so that
        this instance stays usable after clearing.

        Returns:
            str: log.
        """
        self.collection.drop()
        self._open_collection()
        return "Obliviated"

    def get_relevant(self, data, num_relevant: int = 5) -> list:
        """Return the top-k relevant data in memory.

        Args:
            data: The data to compare to.
            num_relevant (int, optional): The max number of relevant data.
                Defaults to 5.

        Returns:
            list: The raw text of each hit for the query embedding.
        """
        # search the embedding and return the most relevant text.
        embedding = get_ada_embedding(data)
        search_params = {
            # The key must be "metric_type" and match the index metric.
            "metric_type": self._METRIC_TYPE,
            "params": {"nprobe": 8},
        }
        result = self.collection.search(
            [embedding], self._VECTOR_FIELD, search_params, num_relevant,
            output_fields=[self._TEXT_FIELD])
        # One query vector was sent, so only result[0] holds hits.
        return [item.entity.value_of_field(self._TEXT_FIELD) for item in result[0]]

    def get_stats(self) -> str:
        """
        Returns: The stats of the milvus cache.
        """
        return f"Entities num: {self.collection.num_entities}"