Format document

Jan Mrna
2025-11-06 10:46:54 +01:00
parent 2fb7a7d224
commit e352780a3d
2 changed files with 220 additions and 156 deletions

db.py

@@ -6,15 +6,16 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import pymupdf
import ollama  # TODO split to another file
#
# Types
#
type Vector = np.typing.NDArray  # TODO pin the dtype, e.g. np.typing.NDArray[np.float32]?
type VectorBytes = bytes
@dataclass(slots=True)
class Record:
    document_index: int
@@ -22,12 +23,14 @@ class Record:
    text: str
    chunk: int = 0  # Chunk number within the page (0-indexed)
@dataclass(slots=True)
class QueryResult:
    record: Record
    distance: float
    document_name: str
@dataclass(slots=True)
class Database:
"""
@@ -36,41 +39,45 @@ class Database:
    TODO For faster nearest neighbour lookup we should use something else,
    e.g. kd-trees (a sketch follows _find_nearest below)
    """

    vectors: list[Vector]
    records: dict[VectorBytes, Record]
    documents: list[Path]
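    # NOTE records is keyed by the raw bytes of each vector (vector.tobytes()),
    # which makes the post-search lookup O(1) but assumes no two chunks embed
    # to the exact same vector (a collision would silently overwrite a record).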
#
# Internal functions
#
def _find_nearest(
    vectors_db: list[Vector], query_vector: Vector, count: int = 10
) -> list[tuple[float, int]]:
"""
Find the N nearest vectors to the query embedding.
Args:
vectors_db: List of vectors in the database
query_vector: Query embedding vector
n: Number of nearest neighbors to return
Returns:
List of (distance, index) tuples, sorted by distance (closest first)
"""
distances = [np.linalg.norm(x - query_vector) for x in vectors_db]
# Get indices sorted by distance
sorted_indices = np.argsort(distances)
# Return top N results as (distance, index) tuples
results = []
for i in range(min(count, len(sorted_indices))):
idx = sorted_indices[i]
results.append((float(distances[idx]), int(idx)))
return results
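
# Sketch of the kd-tree lookup mentioned in the Database docstring TODO.
# Assumes scipy as a hypothetical extra dependency; not wired into query().
def _find_nearest_kdtree(
    vectors_db: list[Vector], query_vector: Vector, count: int = 10
) -> list[tuple[float, int]]:
    from scipy.spatial import cKDTree

    # Built once per call here for simplicity; a real index would be cached
    # on the Database and rebuilt after add_document().
    tree = cKDTree(np.stack(vectors_db))
    distances, indices = tree.query(query_vector, k=min(count, len(vectors_db)))
    # query() returns scalars when k == 1; normalize to arrays
    distances, indices = np.atleast_1d(distances), np.atleast_1d(indices)
    return [(float(d), int(i)) for d, i in zip(distances, indices)]
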
def _embed(text: str) -> Vector:
"""
Generate embedding vector for given text.
@@ -82,13 +89,15 @@ def _embed(text: str) -> Vector:
def _vectorize_record(record: Record) -> tuple[Record, Vector]:
    return record, _embed(record.text)
#
# High-level (exported) functions
#
def create_dummy() -> Database:
    db_length: Final[int] = 10
    vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)]
    records = {
        vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
    }
@@ -98,7 +107,7 @@ def create_dummy() -> Database:
def create_empty() -> Database:
"""
Creates a new empty database with no vectors or records.
Returns:
Empty Database object
"""
@@ -108,105 +117,109 @@ def create_empty() -> Database:
def load(database_file: Path) -> Database:
"""
Loads a database from the given file.
Args:
database_file: Path to the database file
Returns:
Database object loaded from file
"""
if not database_file.exists():
raise FileNotFoundError(f"Database file not found: {database_file}")
with open(database_file, 'rb') as f:
with open(database_file, "rb") as f:
serializable_db = pickle.load(f)
# Reconstruct vectors from bytes
vectors = []
vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64'))
vector_shape = serializable_db.get('vector_shape', ())
for vector_bytes in serializable_db['vectors']:
vector_dtype = np.dtype(serializable_db.get("vector_dtype", "float64"))
vector_shape = serializable_db.get("vector_shape", ())
for vector_bytes in serializable_db["vectors"]:
vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
vectors.append(vector)
# Records already use bytes as keys, so we can use them directly
records = serializable_db['records']
documents = serializable_db['documents']
# Records already use bytes as keys, so we can use them directly
records = serializable_db["records"]
documents = serializable_db["documents"]
return Database(vectors, records, documents)
def save(db: Database, database_file: Path) -> None:
"""
Saves the database to a file using pickle serialization.
Args:
db: The Database object to save
database_file: Path where to save the database file
"""
# Ensure the directory exists
database_file.parent.mkdir(parents=True, exist_ok=True)
# Create a serializable version of the database
# Records already use bytes as keys, so we can use them directly
serializable_db = {
'vectors': [vector.tobytes() for vector in db.vectors],
'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
'vector_shape': db.vectors[0].shape if db.vectors else (),
'records': db.records, # Already uses bytes as keys
'documents': db.documents,
"vectors": [vector.tobytes() for vector in db.vectors],
"vector_dtype": str(db.vectors[0].dtype) if db.vectors else "float64",
"vector_shape": db.vectors[0].shape if db.vectors else (),
"records": db.records, # Already uses bytes as keys
"documents": db.documents,
}
# Save to file
with open(database_file, 'wb') as f:
with open(database_file, "wb") as f:
pickle.dump(serializable_db, f)
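
# Round-trip sketch (hypothetical path) showing how save() and load() pair up:
#   db = create_dummy()
#   save(db, Path("index.pkl"))
#   assert len(load(Path("index.pkl")).vectors) == len(db.vectors)
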
def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]:
"""
Query the database and return the N nearest records.
Args:
db: Database object or path to database file
text: Query text to search for
record_count: Number of nearest neighbors to return (default: 10)
Returns:
List of (distance, Record) tuples, sorted by distance (closest first)
"""
    if isinstance(db, Path):
        db = load(db)
    # Generate embedding for query text
    query_vector = _embed(text)
    # Find nearest vectors
    # NOTE We're using Euclidean distance as the similarity metric; there are
    # alternatives (cosine or dot product) which could be used instead.
    # See https://en.wikipedia.org/wiki/Embedding_(machine_learning)
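    # For unit-normalized embeddings the orderings coincide, since
    # ||a - b||^2 == 2 - 2 * cos(a, b); a cosine-distance variant would rank by
    #   1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))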
    nearest_results = _find_nearest(db.vectors, query_vector, record_count)
    # Convert results to QueryResult objects
    results: list[QueryResult] = []
    for distance, vector_idx in nearest_results:
        # Get the vector at this index
        vector = db.vectors[vector_idx]
        vector_bytes = vector.tobytes()
        # Look up the corresponding record
        if vector_bytes in db.records:
            record = db.records[vector_bytes]
            results.append(
                QueryResult(record, distance, db.documents[record.document_index].name)
            )
    return results
def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
"""
Adds a new document to the database. If path is given, do load, add, save.
Loads PDF with PyMuPDF, splits by pages, and creates records and vectors.
Uses multithreading for embedding generation.
Args:
db: Database object or path to database file
file: Path to PDF file to add
@@ -215,25 +228,25 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
"""
save_to_file = False
database_file_path = None
if isinstance(db, Path):
database_file_path = db
db = load(db)
save_to_file = True
if not file.exists():
raise FileNotFoundError(f"File not found: {file}")
if file.suffix.lower() != '.pdf':
if file.suffix.lower() != ".pdf":
raise ValueError(f"File must be a PDF: {file}")
print(f"Processing PDF: {file}")
document_index = len(db.documents)
try:
doc = pymupdf.open(file)
print(f"PDF opened successfully: {len(doc)} pages")
records: list[Record] = []
chunk_size = 1024
@@ -243,13 +256,15 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
            if not text:
                print(f" Page {page_num + 1}: Skipped (empty)")
                continue
            # Simple chunking - split text into chunks of specified size
            # (an overlapping variant is sketched after this try block)
            for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
                chunk = text[i : i + chunk_size]
                if chunk_stripped := chunk.strip():  # Only add non-empty chunks
                    # page_num + 1 for user friendliness
                    records.append(
                        Record(document_index, page_num + 1, chunk_stripped, chunk_idx)
                    )
        doc.close()
    except Exception as e:
        raise RuntimeError(f"Error processing PDF {file}: {e}") from e
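    # A hypothetical overlapping variant of the chunking above; overlap keeps
    # sentences that straddle a chunk boundary retrievable from both sides:
    #   step = chunk_size - overlap  # e.g. overlap = 128
    #   for chunk_idx, i in enumerate(range(0, len(text), step)):
    #       chunk = text[i : i + chunk_size]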
@@ -261,34 +276,35 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
    # TODO measure with GIL disabled to check if multithreading actually helps
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(_vectorize_record, r) for r in records]
        for f in as_completed(futures):
            record, vector = f.result()
            db.records[vector.tobytes()] = record
            db.vectors.append(vector)
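    # One way to check the TODO above without disabling the GIL: time this
    # block with max_workers=1 vs. max_workers=4; if the ollama embedding
    # calls spend their time waiting on the server (I/O-bound), threads help
    # even with the GIL held.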
print(f"Successfully processed {file}: {len(records)} chunks")
# Save database if we loaded it from file
if save_to_file and database_file_path:
save(db, database_file_path)
print(f"Database saved to {database_file_path}")
def get_document_path(db: Database | Path, document_index: int) -> Path:
"""
Get the file path of the document at the given index in the database.
Args:
db: Database object or path to database file
document_index: Index of the document to retrieve
Returns:
Path to the document file
"""
if isinstance(db, Path):
db = load(db)
if document_index < 0 or document_index >= len(db.documents):
raise IndexError(f"Document index out of range: {document_index}")
return db.documents[document_index]
return db.documents[document_index]
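
# End-to-end usage sketch (hypothetical file names):
#   db_path = Path("library.db")
#   save(create_empty(), db_path)
#   add_document(db_path, Path("manual.pdf"))
#   for result in query(db_path, "how do I reset the device?", record_count=5):
#       print(result.distance, result.document_name, result.record.text[:80])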