Format document

db.py
@@ -6,15 +6,16 @@ from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pymupdf
import ollama # TODO split to another file

#
# Types
#

type Vector = np.NDArray # np.NDArray[np.float32] ?
type VectorBytes = bytes


@dataclass(slots=True)
class Record:
    document_index: int
@@ -22,12 +23,14 @@ class Record:
    text: str
    chunk: int = 0 # Chunk number within the page (0-indexed)


@dataclass(slots=True)
class QueryResult:
    record: Record
    distance: float
    document_name: str


@dataclass(slots=True)
class Database:
    """
@@ -36,41 +39,45 @@ class Database:
    TODO For faster nearest neighbour lookup we should use something else,
    e.g. kd-trees
    """

    vectors: list[Vector]
    records: dict[VectorBytes, Record]
    documents: list[Path]


#
# Internal functions
#


-def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
+def _find_nearest(
+    vectors_db: list[Vector], query_vector: Vector, count: int = 10
+) -> list[tuple[float, int]]:
    """
    Find the N nearest vectors to the query embedding.

    Args:
        vectors_db: List of vectors in the database
        query_vector: Query embedding vector
        n: Number of nearest neighbors to return

    Returns:
        List of (distance, index) tuples, sorted by distance (closest first)
    """
    distances = [np.linalg.norm(x - query_vector) for x in vectors_db]

    # Get indices sorted by distance
    sorted_indices = np.argsort(distances)

    # Return top N results as (distance, index) tuples
    results = []
    for i in range(min(count, len(sorted_indices))):
        idx = sorted_indices[i]
        results.append((float(distances[idx]), int(idx)))

    return results


def _embed(text: str) -> Vector:
    """
    Generate embedding vector for given text.
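The Database docstring above leaves faster nearest-neighbour lookup as a TODO and suggests kd-trees. As a rough illustration of that idea (not part of this commit, and it would add a SciPy dependency), _find_nearest could be backed by scipy.spatial.cKDTree:

# Sketch only: kd-tree variant of _find_nearest, assuming SciPy is available.
import numpy as np
from scipy.spatial import cKDTree


def _find_nearest_kdtree(
    vectors_db: list[np.ndarray], query_vector: np.ndarray, count: int = 10
) -> list[tuple[float, int]]:
    # Build the tree from an (n, dim) matrix; a long-lived Database could cache it.
    tree = cKDTree(np.stack(vectors_db))
    distances, indices = tree.query(query_vector, k=min(count, len(vectors_db)))
    # k=1 returns scalars, so normalise to 1-D arrays before zipping.
    distances, indices = np.atleast_1d(distances), np.atleast_1d(indices)
    return [(float(d), int(i)) for d, i in zip(distances, indices)]

The brute-force version in the hunk above is O(n) per query with a Python-level loop; a kd-tree helps most when the database grows large and the embedding dimension stays modest.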
@@ -82,13 +89,15 @@ def _embed(text: str) -> Vector:
def _vectorize_record(record: Record) -> tuple[Record, Vector]:
    return record, _embed(record.text)


#
# High-level (exported) functions
#


def create_dummy() -> Database:
    db_length: Final[int] = 10
-    vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)]
+    vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)]
    records = {
        vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
    }
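A note on the vector.tobytes(): Record(...) mapping that create_dummy (and the rest of the file) relies on: NumPy arrays are not hashable, so the raw buffer returned by tobytes() stands in as the dictionary key, and query() later recomputes the same bytes to look a record back up. A minimal illustration (made-up values, not from db.py):

import numpy as np

v = np.array([1.0, 2.0, 3.0])
# {v: ...} would raise TypeError: unhashable type: 'numpy.ndarray',
# but the raw buffer is plain bytes and hashes fine.
records = {v.tobytes(): "record for v"}
assert records[np.array([1.0, 2.0, 3.0]).tobytes()] == "record for v"

The lookup only works as long as the query side rebuilds an array with the same dtype and byte order, which is why load() and save() below keep vector_dtype and vector_shape alongside the raw bytes.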
@@ -98,7 +107,7 @@ def create_dummy() -> Database:
def create_empty() -> Database:
    """
    Creates a new empty database with no vectors or records.

    Returns:
        Empty Database object
    """
@@ -108,105 +117,109 @@ def create_empty() -> Database:
def load(database_file: Path) -> Database:
    """
    Loads a database from the given file.

    Args:
        database_file: Path to the database file

    Returns:
        Database object loaded from file
    """
    if not database_file.exists():
        raise FileNotFoundError(f"Database file not found: {database_file}")

-    with open(database_file, 'rb') as f:
+    with open(database_file, "rb") as f:
        serializable_db = pickle.load(f)

    # Reconstruct vectors from bytes
    vectors = []
-    vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64'))
-    vector_shape = serializable_db.get('vector_shape', ())
+    vector_dtype = np.dtype(serializable_db.get("vector_dtype", "float64"))
+    vector_shape = serializable_db.get("vector_shape", ())

-    for vector_bytes in serializable_db['vectors']:
+    for vector_bytes in serializable_db["vectors"]:
        vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
        vectors.append(vector)

    # Records already use bytes as keys, so we can use them directly
-    records = serializable_db['records']
+    records = serializable_db["records"]

-    documents = serializable_db['documents']
+    documents = serializable_db["documents"]

    return Database(vectors, records, documents)


def save(db: Database, database_file: Path) -> None:
    """
    Saves the database to a file using pickle serialization.

    Args:
        db: The Database object to save
        database_file: Path where to save the database file
    """
    # Ensure the directory exists
    database_file.parent.mkdir(parents=True, exist_ok=True)

    # Create a serializable version of the database
    # Records already use bytes as keys, so we can use them directly
    serializable_db = {
-        'vectors': [vector.tobytes() for vector in db.vectors],
-        'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
-        'vector_shape': db.vectors[0].shape if db.vectors else (),
-        'records': db.records, # Already uses bytes as keys
-        'documents': db.documents,
+        "vectors": [vector.tobytes() for vector in db.vectors],
+        "vector_dtype": str(db.vectors[0].dtype) if db.vectors else "float64",
+        "vector_shape": db.vectors[0].shape if db.vectors else (),
+        "records": db.records, # Already uses bytes as keys
+        "documents": db.documents,
    }

    # Save to file
-    with open(database_file, 'wb') as f:
+    with open(database_file, "wb") as f:
        pickle.dump(serializable_db, f)


def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]:
    """
    Query the database and return the N nearest records.

    Args:
        db: Database object or path to database file
        text: Query text to search for
        record_count: Number of nearest neighbors to return (default: 10)

    Returns:
        List of (distance, Record) tuples, sorted by distance (closest first)
    """
    if isinstance(db, Path):
        db = load(db)

    # Generate embedding for query text
    query_vector = _embed(text)

    # Find nearest vectors
    # NOTE We're using euclidean distance as a metric of similarity,
    # there are some alternatives (cos or dot product) which may be used.
    # See https://en.wikipedia.org/wiki/Embedding_(machine_learning)
    nearest_results = _find_nearest(db.vectors, query_vector, record_count)

    # Convert results to (distance, Record) tuples
    results: list[QueryResult] = []
    for distance, vector_idx in nearest_results:
        # Get the vector at this index
        vector = db.vectors[vector_idx]
        vector_bytes = vector.tobytes()

        # Look up the corresponding record
        if vector_bytes in db.records:
            record = db.records[vector_bytes]
-            results.append(QueryResult(record, distance, db.documents[record.document_index].name))
+            results.append(
+                QueryResult(record, distance, db.documents[record.document_index].name)
+            )

    return results


def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
    """
    Adds a new document to the database. If path is given, do load, add, save.
    Loads PDF with PyMuPDF, splits by pages, and creates records and vectors.
    Uses multithreading for embedding generation.

    Args:
        db: Database object or path to database file
        file: Path to PDF file to add
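The NOTE inside query() above points out that Euclidean distance is only one possible similarity metric, with cosine or dot-product similarity as alternatives. A minimal sketch of a cosine-distance version of the distance computation (illustrative only, not in db.py):

import numpy as np


def _cosine_distances(vectors_db: list[np.ndarray], query_vector: np.ndarray) -> list[float]:
    # Cosine distance = 1 - cosine similarity; 0 means "same direction".
    q = query_vector / np.linalg.norm(query_vector)
    return [float(1.0 - np.dot(v / np.linalg.norm(v), q)) for v in vectors_db]

For embeddings that are already unit-normalised, the two rankings coincide (||a - b||^2 = 2 - 2 a·b), so the choice of metric mainly matters when the vectors are not normalised.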
@@ -215,25 +228,25 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
    """
    save_to_file = False
    database_file_path = None

    if isinstance(db, Path):
        database_file_path = db
        db = load(db)
        save_to_file = True

    if not file.exists():
        raise FileNotFoundError(f"File not found: {file}")

-    if file.suffix.lower() != '.pdf':
+    if file.suffix.lower() != ".pdf":
        raise ValueError(f"File must be a PDF: {file}")

    print(f"Processing PDF: {file}")

    document_index = len(db.documents)
    try:
        doc = pymupdf.open(file)
        print(f"PDF opened successfully: {len(doc)} pages")

        records: list[Record] = []
        chunk_size = 1024

@@ -243,13 +256,15 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
            if not text:
                print(f" Page {page_num + 1}: Skipped (empty)")
                continue

            # Simple chunking - split text into chunks of specified size
            for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
-                chunk = text[i:i + chunk_size]
+                chunk = text[i : i + chunk_size]
                if chunk_stripped := chunk.strip(): # Only add non-empty chunks
                    # page_num + 1 for use friendliness
-                    records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx))
+                    records.append(
+                        Record(document_index, page_num + 1, chunk_stripped, chunk_idx)
+                    )
        doc.close()
    except Exception as e:
        raise RuntimeError(f"Error processing PDF {file}: {e}")
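The chunking above is a plain fixed-width split: each page's text is cut into 1024-character pieces, whitespace-only pieces are dropped, and the slot index becomes the Record's chunk number. Factored out as a standalone helper (hypothetical, shown only to make the loop easier to read in isolation):

def _chunks(text: str, chunk_size: int = 1024) -> list[tuple[int, str]]:
    # (chunk_index, stripped_text) pairs; indices count every slot, including
    # skipped empty ones, matching the enumerate(range(...)) loop above.
    out: list[tuple[int, str]] = []
    for chunk_idx, start in enumerate(range(0, len(text), chunk_size)):
        if stripped := text[start : start + chunk_size].strip():
            out.append((chunk_idx, stripped))
    return out

Cutting on raw character offsets can split words and sentences in half; chunking on sentence or paragraph boundaries, or with a small overlap, is a common refinement.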
@@ -261,34 +276,35 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:

    # TODO measure with GIL disabled to check if multithreading actually helps
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
-        futures = [ pool.submit(_vectorize_record, r) for r in records ]
+        futures = [pool.submit(_vectorize_record, r) for r in records]
        for f in as_completed(futures):
            record, vector = f.result()
            db.records[vector.tobytes()] = record
            db.vectors.append(vector)

    print(f"Successfully processed {file}: {len(records)} chunks")

    # Save database if we loaded it from file
    if save_to_file and database_file_path:
        save(db, database_file_path)
        print(f"Database saved to {database_file_path}")


def get_document_path(db: Database | Path, document_index: int) -> Path:
    """
    Get the file path of the document at the given index in the database.

    Args:
        db: Database object or path to database file
        document_index: Index of the document to retrieve

    Returns:
        Path to the document file
    """
    if isinstance(db, Path):
        db = load(db)

    if document_index < 0 or document_index >= len(db.documents):
        raise IndexError(f"Document index out of range: {document_index}")

    return db.documents[document_index]
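Taken together, the functions in db.py form a small ingest-and-search API. A hedged end-to-end sketch of how it is presumably meant to be used (the module name, file names, and query string are assumptions, not shown in this commit, and add_document needs the embedding backend behind _embed, i.e. ollama per the import above, to be reachable):

from pathlib import Path

import db  # assuming db.py is importable under this name

database_path = Path("index.pickle")  # hypothetical location

# Build an index once: create, persist, then add a PDF. Passing a Path makes
# add_document load the database, append, and save it back itself.
db.save(db.create_empty(), database_path)
db.add_document(database_path, Path("manual.pdf"))

# Query it: each QueryResult carries the matched Record, its distance, and the
# source document's file name.
for result in db.query(database_path, "how do I reset the device?", record_count=5):
    print(result.distance, result.document_name, result.record.text[:80])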