feat: Document table for storing original loaded documents #867

Open · wants to merge 15 commits into base: main
2 changes: 1 addition & 1 deletion memgpt/agent_store/chroma.py
@@ -34,7 +34,7 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None
         self.include = ["documents", "embeddings", "metadatas"]

         # need to be converted to strings
-        self.uuid_fields = ["id", "user_id", "agent_id", "source_id"]
+        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]

     def get_filters(self, filters: Optional[Dict] = {}):
         # get all filters for query
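The comment above notes that these fields "need to be converted to strings": Chroma only accepts primitive metadata values, so UUIDs are stringified on write and parsed back on read. A rough sketch of that round trip (the helper functions below are hypothetical, not part of this PR):

```python
import uuid

# Field list matches the diff above, including the new "doc_id".
UUID_FIELDS = ["id", "user_id", "agent_id", "source_id", "doc_id"]

def to_chroma_metadata(record: dict) -> dict:
    # Stringify UUID values before handing them to Chroma.
    return {k: (str(v) if k in UUID_FIELDS and isinstance(v, uuid.UUID) else v) for k, v in record.items()}

def from_chroma_metadata(metadata: dict) -> dict:
    # Parse the stringified UUIDs back on the way out.
    return {k: (uuid.UUID(v) if k in UUID_FIELDS and v is not None else v) for k, v in metadata.items()}
```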
35 changes: 33 additions & 2 deletions memgpt/agent_store/db.py
@@ -1,3 +1,4 @@
+import json
 import os
 from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, DateTime
 from sqlalchemy import func, or_, and_
@@ -22,7 +23,7 @@
 from memgpt.config import MemGPTConfig
 from memgpt.utils import printd
 from memgpt.constants import MAX_EMBEDDING_DIM
-from memgpt.data_types import Record, Message, Passage, ToolCall
+from memgpt.data_types import Record, Message, Passage, ToolCall, Document
 from memgpt.metadata import MetadataStore

 from datetime import datetime
@@ -234,7 +235,37 @@ def to_record(self):
             """Create database model for table_name"""
             class_name = f"{table_name.capitalize()}Model" + dialect
             return create_or_get_model(class_name, MessageModel, table_name)
+        elif table_type == TableType.DOCUMENTS:
+
+            class DocumentModel(Base):
+                """Defines data model for storing Document objects"""
+
+                __abstract__ = True  # this line is necessary
+
+                id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+                user_id = Column(CommonUUID, nullable=False)
+                text = Column(String)
+                data_source = Column(String)
+                metadata_ = Column(MutableJson)
+                # Add a datetime column, with default value as the current time
+                created_at = Column(DateTime(timezone=True), server_default=func.now())
+
+                def __repr__(self):
+                    return f"<Document(document_id='{self.id}', text='{self.text}')>"
+
+                def to_record(self):
+                    return Document(
+                        user_id=self.user_id,
+                        text=self.text,
+                        data_source=self.data_source,
+                        id=self.id,
+                        created_at=str(self.created_at),
+                        metadata=self.metadata_,
+                    )
+
+            """Create database model for table_name"""
+            class_name = f"{table_name.capitalize()}Model" + dialect
+            return create_or_get_model(class_name, DocumentModel, table_name)
         else:
             raise ValueError(f"Table type {table_type} not implemented")
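For illustration, a self-contained, simplified sketch of the table this model defines (plain String/JSON columns stand in for the PR's CommonUUID and MutableJson helpers; the table name is hypothetical):

```python
import uuid

from sqlalchemy import JSON, Column, DateTime, String, func
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class DocumentsModel(Base):
    """Simplified stand-in for the concrete model create_or_get_model would build."""

    __tablename__ = "documents"  # hypothetical name; the PR uses DOCUMENT_TABLE_NAME

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id = Column(String, nullable=False)
    text = Column(String)
    data_source = Column(String)
    metadata_ = Column(JSON)  # stand-in for MutableJson
    created_at = Column(DateTime(timezone=True), server_default=func.now())
```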

@@ -471,7 +502,7 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None
         # get storage URI
         if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
             raise ValueError(f"Table type {table_type} not implemented")
-        elif table_type == TableType.RECALL_MEMORY:
+        elif table_type == TableType.RECALL_MEMORY or table_type == TableType.DOCUMENTS:
             # TODO: eventually implement URI option
             self.path = self.config.recall_storage_path
             if self.path is None:
4 changes: 2 additions & 2 deletions memgpt/agent_store/storage.py
@@ -54,7 +54,7 @@ def __init__(self, table_type: TableType, config: MemGPTConfig, user_id, agent_i
             self.table_name = RECALL_TABLE_NAME
         elif table_type == TableType.DOCUMENTS:
             self.type = Document
-            self.table_name == DOCUMENT_TABLE_NAME
+            self.table_name = DOCUMENT_TABLE_NAME
         elif table_type == TableType.PASSAGES:
             self.type = Passage
             self.table_name = PASSAGE_TABLE_NAME
@@ -86,7 +86,7 @@ def get_filters(self, filters: Optional[Dict] = {}):
 def get_storage_connector(table_type: TableType, config: MemGPTConfig, user_id, agent_id=None):
     if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
         storage_type = config.archival_storage_type
-    elif table_type == TableType.RECALL_MEMORY:
+    elif table_type == TableType.RECALL_MEMORY or table_type == TableType.DOCUMENTS:
         storage_type = config.recall_storage_type
     else:
         raise ValueError(f"Table type {table_type} not implemented")
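With both branches updated, document storage piggybacks on the recall-memory backend. A minimal usage sketch, assuming MemGPTConfig.load() and the connector API shown in this diff:

```python
import uuid

from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig

config = MemGPTConfig.load()
user_id = uuid.uuid4()  # placeholder; normally the real user's id

# Routed to config.recall_storage_type by get_storage_connector; the connector's
# table_name is DOCUMENT_TABLE_NAME thanks to the "==" -> "=" fix above.
doc_store = StorageConnector.get_storage_connector(TableType.DOCUMENTS, config, user_id)
```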
113 changes: 55 additions & 58 deletions memgpt/cli/cli_load.py
Review thread on this file:

Collaborator: Could you please leave a comment on why you're doing doc.text[2:] for future reference?

Collaborator: Also, do we not need to do the same thing for loading webpages?

Contributor (author): I do this because SimpleDirectoryReader adds two newline characters at the start of each chunk.
@@ -8,6 +8,10 @@

 """

+from typing import List
+
+import llama_index
+from llama_index.vector_stores import VectorStoreQuery, SimpleVectorStore
 from typing import List, Optional, Annotated
 from tqdm import tqdm
 import numpy as np
@@ -33,32 +37,6 @@
 app = typer.Typer()


-def insert_passages_into_source(passages: List[Passage], source_name: str, user_id: uuid.UUID, config: MemGPTConfig):
-    """Insert a list of passages into a source by updating storage connectors and metadata store"""
-    storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
-    orig_size = storage.size()
-
-    # insert metadata store
-    ms = MetadataStore(config)
-    source = ms.get_source(user_id=user_id, source_name=source_name)
-    if not source:
-        # create new
-        source = Source(user_id=user_id, name=source_name, created_at=get_local_time())
-        ms.create_source(source)
-
-    # make sure user_id is set for passages
-    for passage in passages:
-        # TODO: attach source IDs
-        # passage.source_id = source.id
-        passage.user_id = user_id
-        passage.data_source = source_name
-
-    # add and save all passages
-    storage.insert_many(passages)
-    assert orig_size + len(passages) == storage.size(), f"Expected {orig_size + len(passages)} passages, got {storage.size()}"
-    storage.save()
-
-
 def insert_passages_into_source(passages: List[Passage], source_name: str, user_id: uuid.UUID, config: MemGPTConfig):
     """Insert a list of passages into a source by updating storage connectors and metadata store"""
     storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
@@ -92,12 +70,6 @@ def store_docs(name, docs, user_id=None, show_progress=True):
     if user_id is None:  # assume running local with single user
         user_id = uuid.UUID(config.anon_clientid)

-    # ensure doc text is not too long
-    # TODO: replace this to instead split up docs that are too large
-    # (this is a temporary fix to avoid breaking the llama index)
-    for doc in docs:
-        doc.text = check_and_split_text(doc.text, config.default_embedding_config.embedding_model)[0]
-
     # record data source metadata
     ms = MetadataStore(config)
     user = ms.get_user(user_id)
@@ -135,39 +107,58 @@ def store_docs(name, docs, user_id=None, show_progress=True):
     # compute and record passages
     embed_model = embedding_model(config.default_embedding_config)

+    storage = StorageConnector.get_storage_connector(TableType.DOCUMENTS, config, user_id)
+    docs_storage = []
+    for doc in docs:
+        doc_storage = Document(user_id=user_id, text=doc.text, id=uuid.UUID(doc.doc_id), data_source=data_source.name)
+        docs_storage.append(doc_storage)
+
     # use llama index to run embeddings code
     with suppress_stdout():
         service_context = ServiceContext.from_defaults(
             llm=None, embed_model=embed_model, chunk_size=config.default_embedding_config.embedding_chunk_size
         )
         index = VectorStoreIndex.from_documents(docs, service_context=service_context, show_progress=True)
-        embed_dict = index._vector_store._data.embedding_dict
-        node_dict = index._docstore.docs
-
-    # TODO: add document store

     # gather passages
+    doc_info = index.ref_doc_info
     passages = []
-    for node_id, node in tqdm(node_dict.items()):
-        vector = embed_dict[node_id]
-        node.embedding = vector
-        text = node.text.replace("\x00", "\uFFFD")  # hacky fix for error on null characters
-        assert (
-            len(node.embedding) == config.default_embedding_config.embedding_dim
-        ), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(node.embedding)}: {node.embedding}"
-        passages.append(
-            Passage(
-                user_id=user.id,
-                text=text,
-                data_source=name,
-                embedding=node.embedding,
-                metadata=None,
-                embedding_dim=config.default_embedding_config.embedding_dim,
-                embedding_model=config.default_embedding_config.embedding_model,
-            )
-        )
+    doc_dict = {}
+    passages_dict = index.docstore.docs
+
+    # SimpleVectorStore has a get method and takes in the node id to retrieve the embedding.
+    if isinstance(index.vector_store, SimpleVectorStore):
+        simple_store = index.vector_store
+        for doc_id, data in doc_info.items():
+            nodes_dict = []
+            for node_id in data.node_ids:
+                text_length = len(passages_dict[node_id].text)
+                assert text_length < config.default_embedding_config.embedding_chunk_size, "Chunk bigger than embedding chunk size!"
+                nodes_dict.append((node_id, simple_store.get(node_id), passages_dict[node_id].to_dict()))
+            doc_dict[doc_id] = nodes_dict
+
+    for doc_id, doc_data in tqdm(doc_dict.items()):
+        for data in doc_data:
+            node = data[2]
+            text = node["text"].replace("\x00", "\uFFFD")  # hacky fix for error on null characters
+            assert (
+                len(data[1]) == config.default_embedding_config.embedding_dim
+            ), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(data[1])}: {data[1]}"
+            passages.append(
+                Passage(
+                    id=uuid.UUID(data[0]),
+                    user_id=user.id,
+                    text=text,
+                    data_source=name,
+                    embedding=data[1],
+                    metadata=None,
+                    doc_id=uuid.UUID(doc_id),
+                    embedding_dim=config.default_embedding_config.embedding_dim,
+                    embedding_model=config.default_embedding_config.embedding_model,
+                )
+            )

+    storage.insert_many(docs_storage)
     insert_passages_into_source(passages, name, user_id, config)
+    storage.save()


 @app.command("index")
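The net effect of the store_docs changes: every stored Passage now carries the doc_id of the Document it was chunked from. A hedged sketch of that linkage using the constructors from this PR (all values illustrative):

```python
import uuid

from memgpt.data_types import Document, Passage

user_id = uuid.uuid4()
doc_id = uuid.uuid4()
embedding_dim = 1536  # assumed; normally config.default_embedding_config.embedding_dim

doc = Document(user_id=user_id, text="full original file text", data_source="my_source", id=doc_id)
passage = Passage(
    user_id=user_id,
    text="one chunk of the file",
    data_source="my_source",
    embedding=[0.0] * embedding_dim,  # placeholder embedding
    metadata=None,
    doc_id=doc_id,  # ties the chunk back to its parent Document
    embedding_dim=embedding_dim,
    embedding_model="text-embedding-ada-002",  # assumed model name
)
assert passage.doc_id == doc.id
```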
@@ -245,8 +236,14 @@ def load_directory(
         reader = SimpleDirectoryReader(input_files=[str(f) for f in input_files])

         # load docs
-        docs = reader.load_data()
-        store_docs(str(name), docs, user_id)
+        docs = []
+        for data in reader.iter_data():
+            # Remove the first two new lines added by SimpleDirectoryReader
+            doc = "".join([doc.text[2:] for doc in data])
+            doco = llama_index.Document()
+            doco.set_content(doc)
+            docs.append(doco)
+        store_docs(name, docs, user_id)

     except ValueError as e:
         typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED)
@@ -259,7 +256,7 @@ def load_webpage(
     urls: Annotated[List[str], typer.Option(help="List of urls to load.")],
 ):
     try:
-        from llama_index import SimpleWebPageReader
+        from llama_index.readers.web import SimpleWebPageReader

         docs = SimpleWebPageReader(html_to_text=True).load_data(urls)
         store_docs(name, docs)
17 changes: 14 additions & 3 deletions memgpt/data_types.py
@@ -269,13 +269,24 @@ def to_openai_dict(self):
 class Document(Record):
     """A document represents a document loaded into MemGPT, which is broken down into passages."""

-    def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None):
+    def __init__(
+        self,
+        user_id: str,
+        text: str,
+        data_source: str,
+        id: Optional[uuid.UUID] = None,
+        created_at: Optional[str] = None,
+        metadata: Optional[dict] = None,
+    ):
         super().__init__(id)
+        if metadata is None:
+            metadata = {}
         self.user_id = user_id
         self.text = text
-        self.document_id = document_id
         self.data_source = data_source
         # TODO: add optional embedding?
+        self.metadata = metadata
+        self.created_at = created_at

     # def __repr__(self) -> str:
     #     pass
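Quick usage sketch of the updated constructor (it assumes, as with the other record types here, that Record assigns a fresh UUID when id is not supplied):

```python
import uuid

from memgpt.data_types import Document

doc = Document(user_id=uuid.uuid4(), text="hello world", data_source="files", metadata=None)
assert doc.metadata == {}  # None is normalized to an empty dict
assert doc.id is not None  # generated by Record when omitted (assumed behavior)
```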
2 changes: 2 additions & 0 deletions tests/test_load_archival.py
@@ -148,6 +148,8 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
     assert [p.data_source == name for p in passages]
     print("Passages", passages)

+    assert all(passages[0].doc_id == passage.doc_id for passage in passages), "Expected all passages to have the same doc id!"
+
     # test: listing sources
     print("Querying all...")
     sources = ms.list_sources(user_id=user_id)