# semantic_doc_search/db.py

import pickle
from pathlib import Path
from dataclasses import dataclass
from typing import Final
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import pymupdf
import ollama # TODO split to another file
#
# Types
#
type Vector = np.typing.NDArray  # np.typing.NDArray[np.float32]?
type VectorBytes = bytes

@dataclass(slots=True)
class Record:
    document_index: int
    page: int
    text: str
    chunk: int = 0  # Chunk number within the page (0-indexed)

@dataclass(slots=True)
class QueryResult:
    record: Record
    distance: float
    document: Path

@dataclass(slots=True)
class Database:
    """
    The "vectors" list holds the embedding vectors for fast lookup;
    a vector's raw bytes are then used as the key to obtain its record.
    TODO For faster nearest-neighbour lookup we should use something else,
    e.g. kd-trees (see the sketch after this class).
    TODO The resulting db is huge (50 MB for 4 PDFs) and contains a lot of
    duplicated, uncompressed text. We should at least de-duplicate the document path.
    """
    vectors: list[Vector]
    records: dict[VectorBytes, Record]
    documents: list[Path]
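
# Minimal sketch of the kd-tree idea from the TODO above. It assumes SciPy as an
# extra dependency (hence the local import) and the name _find_nearest_kdtree is
# hypothetical; it is not wired into the rest of the module. In practice the tree
# would be built once and cached rather than rebuilt per query.
def _find_nearest_kdtree(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
    from scipy.spatial import KDTree  # assumed extra dependency
    tree = KDTree(np.vstack(vectors_db))  # index over all stored vectors
    distances, indices = tree.query(query_vector, k=min(count, len(vectors_db)))
    # query() returns scalars for k=1 and arrays otherwise; normalise to 1-D
    return [(float(d), int(i)) for d, i in zip(np.atleast_1d(distances), np.atleast_1d(indices))]
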
#
# Internal functions
#

def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
    """
    Find the N nearest vectors to the query embedding.
    Args:
        vectors_db: List of vectors in the database
        query_vector: Query embedding vector
        count: Number of nearest neighbors to return
    Returns:
        List of (distance, index) tuples, sorted by distance (closest first)
    """
    distances = [np.linalg.norm(x - query_vector) for x in vectors_db]
    # Get indices sorted by distance
    sorted_indices = np.argsort(distances)
    # Return the top N results as (distance, index) tuples
    results = []
    for i in range(min(count, len(sorted_indices))):
        idx = sorted_indices[i]
        results.append((float(distances[idx]), int(idx)))
    return results

def _embed(text: str) -> Vector:
    """
    Generate embedding vector for given text.
    """
    MODEL: Final[str] = "nomic-embed-text"
    return np.array(ollama.embeddings(model=MODEL, prompt=text)["embedding"])

def _vectorize_record(record: Record) -> tuple[Record, Vector]:
    return record, _embed(record.text)
#
# High-level (exported) functions
#

def create_dummy() -> Database:
    db_length: Final[int] = 10
    vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)]
    records = {
        vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
    }
    return Database(vectors, records, [Path("dummy")])

def create_empty() -> Database:
    """
    Creates a new empty database with no vectors or records.
    Returns:
        Empty Database object
    """
    return Database(vectors=[], records={}, documents=[])

def load(database_file: Path) -> Database:
    """
    Loads a database from the given file.
    Args:
        database_file: Path to the database file
    Returns:
        Database object loaded from file
    """
    if not database_file.exists():
        raise FileNotFoundError(f"Database file not found: {database_file}")
    with open(database_file, 'rb') as f:
        serializable_db = pickle.load(f)
    # Reconstruct vectors from bytes
    vectors = []
    vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64'))
    vector_shape = serializable_db.get('vector_shape', ())
    for vector_bytes in serializable_db['vectors']:
        vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
        vectors.append(vector)
    # Records already use bytes as keys, so we can use them directly
    records = serializable_db['records']
    documents = serializable_db['documents']
    return Database(vectors, records, documents)

def save(db: Database, database_file: Path) -> None:
    """
    Saves the database to a file using pickle serialization.
    Args:
        db: The Database object to save
        database_file: Path where to save the database file
    """
    # Ensure the directory exists
    database_file.parent.mkdir(parents=True, exist_ok=True)
    # Create a serializable version of the database
    # Records already use bytes as keys, so we can use them directly
    serializable_db = {
        'vectors': [vector.tobytes() for vector in db.vectors],
        'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
        'vector_shape': db.vectors[0].shape if db.vectors else (),
        'records': db.records,  # Already uses bytes as keys
        'documents': db.documents,
    }
    # Save to file
    with open(database_file, 'wb') as f:
        pickle.dump(serializable_db, f)

def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]:
    """
    Query the database and return the N nearest records.
    Args:
        db: Database object or path to database file
        text: Query text to search for
        record_count: Number of nearest neighbors to return (default: 10)
    Returns:
        List of QueryResult objects, sorted by distance (closest first)
    """
    if isinstance(db, Path):
        db = load(db)
    # Generate embedding for query text
    query_vector = _embed(text)
    # Find nearest vectors
    # NOTE We're using Euclidean distance as the similarity metric;
    # there are alternatives (cosine or dot product) which may be used instead.
    # See https://en.wikipedia.org/wiki/Embedding_(machine_learning)
    # and the cosine sketch after this function.
    nearest_results = _find_nearest(db.vectors, query_vector, record_count)
    # Convert results to QueryResult objects
    results: list[QueryResult] = []
    for distance, vector_idx in nearest_results:
        # Get the vector at this index
        vector = db.vectors[vector_idx]
        vector_bytes = vector.tobytes()
        # Look up the corresponding record
        if vector_bytes in db.records:
            record = db.records[vector_bytes]
            results.append(QueryResult(record, distance, db.documents[record.document_index]))
    return results
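
# Minimal sketch of the cosine alternative mentioned in the NOTE inside query().
# Cosine distance ranks vectors by angle rather than magnitude, which is often
# preferred for text embeddings. The name _find_nearest_cosine is hypothetical
# and the function is not wired into query() here.
def _find_nearest_cosine(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
    q = query_vector / np.linalg.norm(query_vector)
    # 1 - cosine similarity, so smaller still means "closer"
    distances = [1.0 - float(np.dot(v / np.linalg.norm(v), q)) for v in vectors_db]
    sorted_indices = np.argsort(distances)
    return [(float(distances[i]), int(i)) for i in sorted_indices[:count]]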

def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
    """
    Adds a new document to the database. If a path is given, do load, add, save.
    Loads the PDF with PyMuPDF, splits it by pages, and creates records and vectors.
    Uses multithreading for embedding generation.
    Args:
        db: Database object or path to database file
        file: Path to PDF file to add
        max_workers: Maximum number of threads for parallel processing
    """
    save_to_file = False
    database_file_path = None
    if isinstance(db, Path):
        database_file_path = db
        db = load(db)
        save_to_file = True
    if not file.exists():
        raise FileNotFoundError(f"File not found: {file}")
    if file.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {file}")
    print(f"Processing PDF: {file}")
    document_index = len(db.documents)
    try:
        doc = pymupdf.open(file)
        print(f"PDF opened successfully: {len(doc)} pages")
        records: list[Record] = []
        chunk_size = 1024
        for page_num in range(len(doc)):
            page = doc[page_num]
            text = page.get_text().strip()
            if not text:
                print(f" Page {page_num + 1}: Skipped (empty)")
                continue
            # Simple chunking - split the text into chunks of the specified size
            for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
                chunk = text[i:i + chunk_size]
                if chunk_stripped := chunk.strip():  # Only add non-empty chunks
                    # page_num + 1 for user friendliness
                    records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx))
        doc.close()
    except Exception as e:
        raise RuntimeError(f"Error processing PDF {file}: {e}") from e
    # Process chunks in parallel
    print(f"Processing {len(records)} chunks with {max_workers} workers...")
    db.documents.append(file)
    # TODO measure with the GIL disabled to check whether multithreading actually helps
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(_vectorize_record, r) for r in records]
        for f in as_completed(futures):
            record, vector = f.result()
            db.records[vector.tobytes()] = record
            db.vectors.append(vector)
    print(f"Successfully processed {file}: {len(records)} chunks")
    # Save database if we loaded it from file
    if save_to_file and database_file_path:
        save(db, database_file_path)
        print(f"Database saved to {database_file_path}")