Files
semantic_doc_search/db.py
2025-11-03 13:48:53 +01:00

329 lines
11 KiB
Python

import pickle
from pathlib import Path
from dataclasses import dataclass
from typing import Final
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import pymupdf
import ollama # TODO split to another file
#
# Types
#
type Vector = np.NDArray # np.NDArray[np.float32] ?
type VectorBytes = bytes
@dataclass(slots=True)
class Record:
document: Path
page: int
text: str
chunk: int = 0 # Chunk number within the page (0-indexed)
@dataclass(slots=True)
class Database:
    """
    In-memory store of embedding vectors and the text records they were computed from.

    ``vectors`` holds every embedding in insertion order; nearest-neighbour search
    scans this list linearly. ``records`` maps the raw bytes of a vector
    (``vector.tobytes()``) to the Record it was embedded from, so a search hit in
    ``vectors`` can be turned back into its record. Identical vectors share one key.

    TODO For faster nearest-neighbour lookup we should use a dedicated index
    (e.g. a k-d tree) instead of a linear scan.
    TODO The saved database is large (about 50 MB for 4 PDFs) because it stores a
    lot of duplicate, uncompressed text. At the very least, the document paths
    should be de-duplicated.
    """
    # Embedding vectors, scanned by the nearest-neighbour search.
    vectors: list[Vector]
    # Record lookup keyed by vector.tobytes() of the matching entry in `vectors`.
    records: dict[VectorBytes, Record]
#
# Internal functions
#
def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
    """
    Find the ``count`` vectors closest to ``query_vector`` by Euclidean distance.

    Args:
        vectors_db: Vectors in the database; all must share the same shape.
        query_vector: Query embedding vector, shape-compatible with the stored vectors.
        count: Maximum number of nearest neighbours to return; ``count <= 0`` yields ``[]``.

    Returns:
        List of (distance, index) tuples sorted by distance (closest first), where
        ``index`` points into ``vectors_db``. Equidistant vectors keep database order.
    """
    if count <= 0 or len(vectors_db) == 0:
        return []
    # One vectorised pass instead of computing one norm per vector in Python.
    # Flattening each difference row keeps the "norm over all elements" semantics
    # np.linalg.norm uses for a single (possibly multi-dimensional) array.
    differences = np.stack(vectors_db) - query_vector
    distances = np.linalg.norm(differences.reshape(len(vectors_db), -1), axis=1)
    # A stable sort makes the order of equidistant vectors deterministic.
    nearest = np.argsort(distances, kind="stable")[:count]
    return [(float(distances[i]), int(i)) for i in nearest]
def _embed(text: str) -> Vector:
    """
    Embed ``text`` with the Ollama ``nomic-embed-text`` model.

    Returns the model's embedding list converted to a NumPy array.
    """
    response = ollama.embeddings(model="nomic-embed-text", prompt=text)
    return np.array(response["embedding"])
def _process_chunk(chunk_data: tuple) -> tuple[Vector, Record]:
    """
    Embed one chunk of text and wrap it in a Record.

    Designed to be submitted to a thread pool, one call per chunk.

    Args:
        chunk_data: ``(file_path, page_num, chunk_idx, chunk_text)`` where
            ``page_num`` and ``chunk_idx`` are 0-based.

    Returns:
        ``(vector, record)``; the record's page and chunk numbers are shifted
        to 1-based values for display.
    """
    source, page_index, chunk_index, body = chunk_data
    return (
        _embed(body),
        Record(document=source, page=page_index + 1, text=body, chunk=chunk_index + 1),
    )
#
# High-level (exported) functions
#
def create_dummy() -> Database:
    """
    Build a small synthetic database for experiments.

    It holds ten 4-element integer vectors ``[k, 2k, 3k, 4k]`` for ``k`` in 0..9,
    each mapped to its own placeholder Record.
    """
    size: Final[int] = 10
    vectors = [np.array([k * factor for factor in range(1, 5)]) for k in range(size)]
    records: dict[VectorBytes, Record] = {}
    for vec in vectors:
        records[vec.tobytes()] = Record(Path("dummy"), 1, "Lorem my ipsum", 1)
    return Database(vectors=vectors, records=records)
def create_empty() -> Database:
    """
    Return a fresh Database that contains no vectors and no records.
    """
    return Database([], {})
def load(database_file: Path) -> Database:
    """
    Load a database previously written by save().

    Args:
        database_file: Path to the database file.

    Returns:
        Database object reconstructed from the file.

    Raises:
        FileNotFoundError: If ``database_file`` does not exist.
    """
    if not database_file.exists():
        raise FileNotFoundError(f"Database file not found: {database_file}")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load database files that come from a trusted source.
    with open(database_file, 'rb') as f:
        serializable_db = pickle.load(f)
    vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64'))
    # Files without a stored shape are restored as flat 1-D vectors; the former
    # default of () could only reshape single-element buffers and failed otherwise.
    vector_shape = serializable_db.get('vector_shape', (-1,))
    # np.frombuffer returns read-only views over the stored bytes (no copy).
    vectors = [
        np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
        for vector_bytes in serializable_db['vectors']
    ]
    # Records are stored with the same bytes keys the in-memory Database uses.
    records = serializable_db['records']
    return Database(vectors, records)
def save(db: Database, database_file: Path) -> None:
    """
    Save the database to a file using pickle serialization.

    The data is first written to a sibling ``<name>.tmp`` file, which then
    replaces ``database_file``. A failure while writing therefore leaves any
    previously saved database intact instead of truncating it.

    Args:
        db: The Database object to save.
        database_file: Path where the database file is written; missing parent
            directories are created.
    """
    database_file.parent.mkdir(parents=True, exist_ok=True)
    # Vectors are stored as raw bytes. The dtype and shape are taken from the
    # first vector (all vectors are assumed to share them) so load() can rebuild them.
    serializable_db = {
        'vectors': [vector.tobytes() for vector in db.vectors],
        'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
        'vector_shape': db.vectors[0].shape if db.vectors else (),
        'records': db.records,  # Already uses bytes as keys
    }
    tmp_file = database_file.with_name(database_file.name + '.tmp')
    try:
        with open(tmp_file, 'wb') as f:
            pickle.dump(serializable_db, f)
        # Path.replace overwrites an existing target (atomically on POSIX).
        tmp_file.replace(database_file)
    except BaseException:
        # Don't leave a half-written temp file behind; propagate the error.
        tmp_file.unlink(missing_ok=True)
        raise
def query(db: Database | Path, text: str, record_count: int = 10) -> list[tuple[float, Record]]:
    """
    Search the database for the records whose embeddings are closest to ``text``.

    Args:
        db: Database object, or path to a database file that is loaded first.
        text: Query text to search for.
        record_count: How many nearest neighbours to look up (default: 10).

    Returns:
        (distance, Record) pairs ordered closest first. Vectors without a
        matching record are skipped, so fewer than ``record_count`` pairs may
        be returned.
    """
    database = load(db) if isinstance(db, Path) else db
    # NOTE Similarity is measured as euclidean distance; cosine similarity or a
    # dot product are possible alternatives.
    # See https://en.wikipedia.org/wiki/Embedding_(machine_learning)
    hits = _find_nearest(database.vectors, _embed(text), record_count)
    matches = []
    for distance, vector_idx in hits:
        key = database.vectors[vector_idx].tobytes()
        try:
            matches.append((distance, database.records[key]))
        except KeyError:
            continue
    return matches
def _split_into_chunks(text: str, chunk_size: int) -> list[str]:
    """Split ``text`` into consecutive ``chunk_size``-character windows, stripped, dropping blank ones."""
    return [
        chunk
        for start in range(0, len(text), chunk_size)
        if (chunk := text[start:start + chunk_size].strip())
    ]


def _collect_pdf_chunks(file: Path, chunk_size: int) -> list[tuple[Path, int, int, str]]:
    """Read the PDF ``file`` and return one ``(file, page_num, chunk_idx, chunk_text)`` task per non-empty chunk (0-based indices)."""
    chunk_tasks = []
    # The context manager closes the document even if text extraction raises.
    with pymupdf.open(file) as doc:
        print(f"PDF opened successfully: {len(doc)} pages")
        for page_num in range(len(doc)):
            text = doc[page_num].get_text().strip()
            if not text:
                print(f" Page {page_num + 1}: Skipped (empty)")
                continue
            chunk_tasks.extend(
                (file, page_num, chunk_idx, chunk_text)
                for chunk_idx, chunk_text in enumerate(_split_into_chunks(text, chunk_size))
            )
    return chunk_tasks


def add_document(db: Database | Path, file: Path, max_workers: int = 1, chunk_size: int = 1024) -> None:
    """
    Add a PDF document to the database. If a path is given, load it, add, and save it back.

    The PDF is read with PyMuPDF. Each page's text is split into chunks of
    ``chunk_size`` characters, and each chunk is embedded in a thread pool and
    stored as a vector plus a Record. Chunks whose embedding fails are reported
    and skipped.

    Args:
        db: Database object (modified in place) or path to a database file.
        file: Path to the PDF file to add.
        max_workers: Maximum number of threads used for embedding generation.
        chunk_size: Number of characters per chunk (default: 1024).

    Raises:
        FileNotFoundError: If ``file`` or the database file does not exist.
        ValueError: If ``file`` is not a ``.pdf`` or ``chunk_size`` is not positive.
        RuntimeError: If opening or processing the PDF fails; the cause is chained.
    """
    database_file_path = db if isinstance(db, Path) else None
    if database_file_path is not None:
        db = load(database_file_path)
    # Validate that the file exists and is a PDF
    if not file.exists():
        raise FileNotFoundError(f"File not found: {file}")
    if file.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {file}")
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    print(f"Processing PDF: {file}")
    try:
        chunk_tasks = _collect_pdf_chunks(file, chunk_size)
        print(f"Processing {len(chunk_tasks)} chunks with {max_workers} workers...")
        # TODO measure with GIL disabled to check if multithreading actually helps
        completed_chunks = 0
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(_process_chunk, chunk_data) for chunk_data in chunk_tasks]
            # Consume results in submission order (not as_completed) so the vector
            # order in the database does not depend on thread scheduling.
            for chunk_data, future in zip(chunk_tasks, futures):
                try:
                    vector, record = future.result()
                except Exception as e:
                    # Best effort: one failed chunk must not abort the whole document.
                    print(f" Error processing chunk {chunk_data}: {e}")
                    continue
                vector_bytes = vector.tobytes()
                # Identical embeddings (e.g. repeated boilerplate text) share one key
                # in db.records. Append the vector only once so a query cannot return
                # the same record twice.
                if vector_bytes not in db.records:
                    db.vectors.append(vector)
                db.records[vector_bytes] = record
                completed_chunks += 1
        print(f"Successfully processed {file}: {completed_chunks}/{len(chunk_tasks)} chunks")
    except Exception as e:
        raise RuntimeError(f"Error processing PDF {file}: {e}") from e
    # Save database if we loaded it from file
    if database_file_path is not None:
        save(db, database_file_path)
        print(f"Database saved to {database_file_path}")