import pickle
from pathlib import Path
from dataclasses import dataclass
from typing import Final
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
from numpy.typing import NDArray
import pymupdf
import ollama  # TODO split off into another file


#
# Types
#

type Vector = NDArray[np.float64]  # TODO consider np.float32 to halve memory
type VectorBytes = bytes

@dataclass(slots=True)
class Record:
    document: Path
    page: int
    text: str
    chunk: int = 0  # Chunk number within the page (stored 1-indexed by _process_chunk)


@dataclass(slots=True)
class Database:
    """
    "Vectors" holds the data for fast nearest-neighbour lookup; the matching
    vector's bytes are then used as the key to obtain its record.

    TODO For faster nearest-neighbour lookup we should use something else,
    e.g. k-d trees.
    TODO The resulting db is huge (50 MB for 4 PDFs) and contains a lot of
    duplicated uncompressed text. We should at least de-duplicate the document path.
    """
    vectors: list[Vector]
    records: dict[VectorBytes, Record]
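
# The records dict is keyed by the raw bytes of each vector, so looking up the
# record after a nearest-neighbour search is just vector.tobytes(). A minimal
# sketch (kept as a comment so it does not run at import time):
#
#     db = create_dummy()
#     v = db.vectors[0]
#     rec = db.records[v.tobytes()]   # Record(document=Path('dummy'), page=1, ...)
#
# Note that tobytes() round-trips exactly only when dtype and shape are stored
# alongside the bytes, which is what save()/load() below do.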


#
# Internal functions
#


def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
    """
    Find the N nearest vectors to the query embedding.

    Args:
        vectors_db: List of vectors in the database
        query_vector: Query embedding vector
        count: Number of nearest neighbors to return

    Returns:
        List of (distance, index) tuples, sorted by distance (closest first)
    """
    distances = [np.linalg.norm(x - query_vector) for x in vectors_db]

    # Get indices sorted by distance
    sorted_indices = np.argsort(distances)

    # Return the top `count` results as (distance, index) tuples
    results = []
    for i in range(min(count, len(sorted_indices))):
        idx = sorted_indices[i]
        results.append((float(distances[idx]), int(idx)))

    return results
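
# The loop above computes each distance at Python level. A vectorized sketch
# (assuming a non-empty database whose vectors share one dtype and shape, so they
# can be stacked) that should return the same result, using np.argpartition to
# avoid sorting all distances:
#
#     def _find_nearest_vectorized(vectors_db, query_vector, count=10):
#         matrix = np.stack(vectors_db)                         # (n, dim)
#         distances = np.linalg.norm(matrix - query_vector, axis=1)
#         count = min(count, len(distances))
#         idx = np.argpartition(distances, count - 1)[:count]   # unsorted top-k
#         idx = idx[np.argsort(distances[idx])]                 # sort only the top-k
#         return [(float(distances[i]), int(i)) for i in idx]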


def _embed(text: str) -> Vector:
    """
    Generate embedding vector for given text.
    """
    MODEL: Final[str] = "nomic-embed-text"
    return np.array(ollama.embeddings(model=MODEL, prompt=text)["embedding"])
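
# Example (assumes a local Ollama server with the nomic-embed-text model already
# pulled; the vector length, typically 768 for this model, is whatever the server
# returns):
#
#     v = _embed("hello world")
#     v.shape, v.dtype   # e.g. ((768,), dtype('float64'))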


def _process_chunk(chunk_data: tuple[Path, int, int, str]) -> tuple[Vector, Record]:
    """
    Process a single chunk: generate embedding and create record.
    Used for multithreading.

    Args:
        chunk_data: Tuple of (file_path, page_num, chunk_idx, chunk_text)

    Returns:
        Tuple of (vector, record)
    """
    file_path, page_num, chunk_idx, chunk_text = chunk_data

    # Generate embedding vector for this chunk
    vector = _embed(chunk_text)

    # Create record for this chunk
    record = Record(
        document=file_path,
        page=page_num + 1,    # 1-indexed for user-friendliness
        text=chunk_text,
        chunk=chunk_idx + 1,  # 1-indexed for user-friendliness
    )

    return vector, record
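
# Expected chunk_data layout, with hypothetical values (page_num and chunk_idx are
# 0-indexed on input and stored 1-indexed in the Record):
#
#     _process_chunk((Path("paper.pdf"), 0, 2, "Some 1024-character slice of page text"))
#     # -> (embedding vector, Record(document=Path('paper.pdf'), page=1, text=..., chunk=3))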


#
# High-level (exported) functions
#


def create_dummy() -> Database:
    db_length: Final[int] = 10
    vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)]
    records = {
        vector.tobytes(): Record(Path("dummy"), 1, "Lorem my ipsum", 1) for vector in vectors
    }
    return Database(vectors, records)


def create_empty() -> Database:
    """
    Creates a new empty database with no vectors or records.

    Returns:
        Empty Database object
    """
    return Database(vectors=[], records={})


def load(database_file: Path) -> Database:
    """
    Loads a database from the given file.

    Args:
        database_file: Path to the database file

    Returns:
        Database object loaded from file
    """
    if not database_file.exists():
        raise FileNotFoundError(f"Database file not found: {database_file}")

    with open(database_file, 'rb') as f:
        serializable_db = pickle.load(f)

    # Reconstruct vectors from bytes
    vectors = []
    vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64'))
    vector_shape = serializable_db.get('vector_shape', ())

    for vector_bytes in serializable_db['vectors']:
        vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
        vectors.append(vector)

    # Records already use bytes as keys, so we can use them directly
    records = serializable_db['records']

    return Database(vectors, records)


def save(db: Database, database_file: Path) -> None:
    """
    Saves the database to a file using pickle serialization.

    Args:
        db: The Database object to save
        database_file: Path where to save the database file
    """
    # Ensure the directory exists
    database_file.parent.mkdir(parents=True, exist_ok=True)

    # Create a serializable version of the database.
    # All vectors are assumed to share the dtype and shape of the first one.
    # Records already use bytes as keys, so we can use them directly.
    serializable_db = {
        'vectors': [vector.tobytes() for vector in db.vectors],
        'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
        'vector_shape': db.vectors[0].shape if db.vectors else (),
        'records': db.records  # Already uses bytes as keys
    }

    # Save to file
    with open(database_file, 'wb') as f:
        pickle.dump(serializable_db, f)
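
# Round-trip sketch (hypothetical path; unpickling an untrusted file is unsafe,
# so only load databases you created yourself):
#
#     db = create_dummy()
#     save(db, Path("out/db.pickle"))
#     db2 = load(Path("out/db.pickle"))
#     assert len(db2.vectors) == len(db.vectors)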


def query(db: Database | Path, text: str, record_count: int = 10) -> list[tuple[float, Record]]:
    """
    Query the database and return the N nearest records.

    Args:
        db: Database object or path to database file
        text: Query text to search for
        record_count: Number of nearest neighbors to return (default: 10)

    Returns:
        List of (distance, Record) tuples, sorted by distance (closest first)
    """
    if isinstance(db, Path):
        db = load(db)

    # Generate embedding for query text
    query_vector = _embed(text)

    # Find nearest vectors
    # NOTE We're using Euclidean distance as the similarity metric; alternatives
    # such as cosine similarity or dot product could be used instead.
    # See https://en.wikipedia.org/wiki/Embedding_(machine_learning)
    nearest_results = _find_nearest(db.vectors, query_vector, record_count)

    # Convert results to (distance, Record) tuples
    results = []
    for distance, vector_idx in nearest_results:
        # Get the vector at this index
        vector = db.vectors[vector_idx]
        vector_bytes = vector.tobytes()

        # Look up the corresponding record
        if vector_bytes in db.records:
            record = db.records[vector_bytes]
            results.append((distance, record))

    return results
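
# Example (assumes a database file previously built with add_document() and a
# running Ollama server; the path and query are hypothetical):
#
#     for distance, record in query(Path("out/db.pickle"), "how is attention computed?", record_count=3):
#         print(f"{distance:7.3f}  {record.document.name}  p.{record.page} chunk {record.chunk}")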


def add_document(db: Database | Path, file: Path, max_workers: int = 1) -> None:
    """
    Adds a new document to the database. If a path is given, the database is
    loaded from it, updated, and saved back.
    Loads the PDF with PyMuPDF, splits it into pages and chunks, and creates
    records and vectors. Uses multithreading for embedding generation.

    Args:
        db: Database object or path to database file
        file: Path to PDF file to add
        max_workers: Maximum number of threads for parallel processing
    """
    save_to_file = False
    database_file_path = None

    if isinstance(db, Path):
        database_file_path = db
        db = load(db)
        save_to_file = True

    # Validate that the file exists and is a PDF
    if not file.exists():
        raise FileNotFoundError(f"File not found: {file}")

    if file.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {file}")

    print(f"Processing PDF: {file}")

    # Open PDF with PyMuPDF
    try:
        doc = pymupdf.open(file)
        print(f"PDF opened successfully: {len(doc)} pages")

        # Collect all chunks to process
        chunk_tasks = []

        # Process each page to collect chunks
        for page_num in range(len(doc)):
            page = doc[page_num]

            # Extract text from page
            text = page.get_text().strip()

            # Skip empty pages
            if not text:
                print(f"  Page {page_num + 1}: Skipped (empty)")
                continue

            #print(f"  Page {page_num + 1}: {len(text)} characters")

            # Split page text into chunks of 1024 characters
            chunk_size = 1024
            chunks = []

            # Simple chunking - split text into chunks of specified size
            for i in range(0, len(text), chunk_size):
                chunk = text[i:i + chunk_size]
                if chunk.strip():  # Only add non-empty chunks
                    chunks.append(chunk.strip())

            #print(f"  Split into {len(chunks)} chunks")

            # Add chunk tasks for parallel processing
            for chunk_idx, chunk_text in enumerate(chunks):
                chunk_tasks.append((file, page_num, chunk_idx, chunk_text))

        doc.close()

        # Process chunks in parallel
        print(f"Processing {len(chunk_tasks)} chunks with {max_workers} workers...")

        # TODO measure with GIL disabled to check if multithreading actually helps
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all chunk processing tasks
            future_to_chunk = {
                executor.submit(_process_chunk, chunk_data): chunk_data
                for chunk_data in chunk_tasks
            }

            # Collect results as they complete
            completed_chunks = 0
            for future in as_completed(future_to_chunk):
                try:
                    vector, record = future.result()

                    # Convert vector to bytes for use as dictionary key
                    vector_bytes = vector.tobytes()

                    # Add to database
                    db.vectors.append(vector)
                    db.records[vector_bytes] = record

                    completed_chunks += 1
                    # if completed_chunks % 10 == 0 or completed_chunks == len(chunk_tasks):
                    #     print(f"  Completed {completed_chunks}/{len(chunk_tasks)} chunks")

                except Exception as e:
                    chunk_data = future_to_chunk[future]
                    print(f"  Error processing chunk {chunk_data}: {e}")

        print(f"Successfully processed {file}: {len(chunk_tasks)} chunks")

    except Exception as e:
        raise RuntimeError(f"Error processing PDF {file}: {e}") from e

    # Save database if we loaded it from file
    if save_to_file and database_file_path:
        save(db, database_file_path)
        print(f"Database saved to {database_file_path}")