import pickle from pathlib import Path from dataclasses import dataclass from typing import Final from concurrent.futures import ThreadPoolExecutor, as_completed import numpy as np import pymupdf import ollama # TODO split to another file # # Types # type Vector = np.NDArray # np.NDArray[np.float32] ? type VectorBytes = bytes @dataclass(slots=True) class Record: document_index: int page: int text: str chunk: int = 0 # Chunk number within the page (0-indexed) @dataclass(slots=True) class QueryResult: record: Record distance: float document_name: str @dataclass(slots=True) class Database: """ "Vectors" hold the data for fast lookup of the vector, which can then be used to obtain record. TODO For faster nearest neighbour lookup we should use something else, e.g. kd-trees """ vectors: list[Vector] records: dict[VectorBytes, Record] documents: list[Path] # # Internal functions # def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]: """ Find the N nearest vectors to the query embedding. Args: vectors_db: List of vectors in the database query_vector: Query embedding vector n: Number of nearest neighbors to return Returns: List of (distance, index) tuples, sorted by distance (closest first) """ distances = [np.linalg.norm(x - query_vector) for x in vectors_db] # Get indices sorted by distance sorted_indices = np.argsort(distances) # Return top N results as (distance, index) tuples results = [] for i in range(min(count, len(sorted_indices))): idx = sorted_indices[i] results.append((float(distances[idx]), int(idx))) return results def _embed(text: str) -> Vector: """ Generate embedding vector for given text. """ MODEL: Final[str] = "nomic-embed-text" return np.array(ollama.embeddings(model=MODEL, prompt=text)["embedding"]) def _vectorize_record(record: Record) -> tuple[Record, Vector]: return record, _embed(record.text) # # High-level (exported) functions # def create_dummy() -> Database: db_length: Final[int] = 10 vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)] records = { vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors } return Database(vectors, records, [Path("dummy")]) def create_empty() -> Database: """ Creates a new empty database with no vectors or records. Returns: Empty Database object """ return Database(vectors=[], records={}, documents=[]) def load(database_file: Path) -> Database: """ Loads a database from the given file. Args: database_file: Path to the database file Returns: Database object loaded from file """ if not database_file.exists(): raise FileNotFoundError(f"Database file not found: {database_file}") with open(database_file, 'rb') as f: serializable_db = pickle.load(f) # Reconstruct vectors from bytes vectors = [] vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64')) vector_shape = serializable_db.get('vector_shape', ()) for vector_bytes in serializable_db['vectors']: vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape) vectors.append(vector) # Records already use bytes as keys, so we can use them directly records = serializable_db['records'] documents = serializable_db['documents'] return Database(vectors, records, documents) def save(db: Database, database_file: Path) -> None: """ Saves the database to a file using pickle serialization. Args: db: The Database object to save database_file: Path where to save the database file """ # Ensure the directory exists database_file.parent.mkdir(parents=True, exist_ok=True) # Create a serializable version of the database # Records already use bytes as keys, so we can use them directly serializable_db = { 'vectors': [vector.tobytes() for vector in db.vectors], 'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64', 'vector_shape': db.vectors[0].shape if db.vectors else (), 'records': db.records, # Already uses bytes as keys 'documents': db.documents, } # Save to file with open(database_file, 'wb') as f: pickle.dump(serializable_db, f) def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]: """ Query the database and return the N nearest records. Args: db: Database object or path to database file text: Query text to search for record_count: Number of nearest neighbors to return (default: 10) Returns: List of (distance, Record) tuples, sorted by distance (closest first) """ if isinstance(db, Path): db = load(db) # Generate embedding for query text query_vector = _embed(text) # Find nearest vectors # NOTE We're using euclidean distance as a metric of similarity, # there are some alternatives (cos or dot product) which may be used. # See https://en.wikipedia.org/wiki/Embedding_(machine_learning) nearest_results = _find_nearest(db.vectors, query_vector, record_count) # Convert results to (distance, Record) tuples results: list[QueryResult] = [] for distance, vector_idx in nearest_results: # Get the vector at this index vector = db.vectors[vector_idx] vector_bytes = vector.tobytes() # Look up the corresponding record if vector_bytes in db.records: record = db.records[vector_bytes] results.append(QueryResult(record, distance, db.documents[record.document_index].name)) return results def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None: """ Adds a new document to the database. If path is given, do load, add, save. Loads PDF with PyMuPDF, splits by pages, and creates records and vectors. Uses multithreading for embedding generation. Args: db: Database object or path to database file file: Path to PDF file to add max_workers: Maximum number of threads for parallel processing """ save_to_file = False database_file_path = None if isinstance(db, Path): database_file_path = db db = load(db) save_to_file = True if not file.exists(): raise FileNotFoundError(f"File not found: {file}") if file.suffix.lower() != '.pdf': raise ValueError(f"File must be a PDF: {file}") print(f"Processing PDF: {file}") document_index = len(db.documents) try: doc = pymupdf.open(file) print(f"PDF opened successfully: {len(doc)} pages") records: list[Record] = [] chunk_size = 1024 for page_num in range(len(doc)): page = doc[page_num] text = page.get_text().strip() if not text: print(f" Page {page_num + 1}: Skipped (empty)") continue # Simple chunking - split text into chunks of specified size for chunk_idx, i in enumerate(range(0, len(text), chunk_size)): chunk = text[i:i + chunk_size] if chunk_stripped := chunk.strip(): # Only add non-empty chunks # page_num + 1 for use friendliness records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx)) doc.close() except Exception as e: raise RuntimeError(f"Error processing PDF {file}: {e}") # Process chunks in parallel print(f"Processing {len(records)} chunks with {max_workers} workers...") db.documents.append(file) # TODO measure with GIL disabled to check if multithreading actually helps with ThreadPoolExecutor(max_workers=max_workers) as pool: futures = [ pool.submit(_vectorize_record, r) for r in records ] for f in as_completed(futures): record, vector = f.result() db.records[vector.tobytes()] = record db.vectors.append(vector) print(f"Successfully processed {file}: {len(records)} chunks") # Save database if we loaded it from file if save_to_file and database_file_path: save(db, database_file_path) print(f"Database saved to {database_file_path}") def get_document_path(db: Database | Path, document_index: int) -> Path: """ Get the file path of the document at the given index in the database. Args: db: Database object or path to database file document_index: Index of the document to retrieve Returns: Path to the document file """ if isinstance(db, Path): db = load(db) if document_index < 0 or document_index >= len(db.documents): raise IndexError(f"Document index out of range: {document_index}") return db.documents[document_index]