Change line endings to LF

Jan Mrna
2025-11-06 10:46:31 +01:00
parent b9636dbd57
commit 2fb7a7d224
2 changed files with 367 additions and 367 deletions

db.py

@@ -1,294 +1,294 @@
import pickle
from pathlib import Path
from dataclasses import dataclass
from typing import Final
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pymupdf
import ollama  # TODO split to another file

#
# Types
#

type Vector = np.typing.NDArray  # np.typing.NDArray[np.float32] ?
type VectorBytes = bytes


@dataclass(slots=True)
class Record:
    document_index: int
    page: int
    text: str
    chunk: int = 0  # Chunk number within the page (0-indexed)


@dataclass(slots=True)
class QueryResult:
    record: Record
    distance: float
    document_name: str


@dataclass(slots=True)
class Database:
    """
    "Vectors" hold the data for fast lookup of the vector,
    which can then be used to obtain the record.
    TODO For faster nearest-neighbour lookup we should use something else,
    e.g. k-d trees
    """
    vectors: list[Vector]
    records: dict[VectorBytes, Record]
    documents: list[Path]


#
# Internal functions
#

def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
    """
    Find the N nearest vectors to the query embedding.

    Args:
        vectors_db: List of vectors in the database
        query_vector: Query embedding vector
        count: Number of nearest neighbors to return

    Returns:
        List of (distance, index) tuples, sorted by distance (closest first)
    """
    distances = [np.linalg.norm(x - query_vector) for x in vectors_db]
    # Get indices sorted by distance
    sorted_indices = np.argsort(distances)
    # Return top N results as (distance, index) tuples
    results = []
    for i in range(min(count, len(sorted_indices))):
        idx = sorted_indices[i]
        results.append((float(distances[idx]), int(idx)))
    return results


def _embed(text: str) -> Vector:
    """
    Generate an embedding vector for the given text.
    """
    MODEL: Final[str] = "nomic-embed-text"
    return np.array(ollama.embeddings(model=MODEL, prompt=text)["embedding"])


def _vectorize_record(record: Record) -> tuple[Record, Vector]:
    return record, _embed(record.text)


#
# High-level (exported) functions
#

def create_dummy() -> Database:
    db_length: Final[int] = 10
    vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)]
    records = {
        vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
    }
    return Database(vectors, records, [Path("dummy")])


def create_empty() -> Database:
    """
    Creates a new empty database with no vectors or records.

    Returns:
        Empty Database object
    """
    return Database(vectors=[], records={}, documents=[])


def load(database_file: Path) -> Database:
    """
    Loads a database from the given file.

    Args:
        database_file: Path to the database file

    Returns:
        Database object loaded from file
    """
    if not database_file.exists():
        raise FileNotFoundError(f"Database file not found: {database_file}")
    with open(database_file, 'rb') as f:
        serializable_db = pickle.load(f)
    # Reconstruct vectors from bytes
    vectors = []
    vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64'))
    vector_shape = serializable_db.get('vector_shape', ())
    for vector_bytes in serializable_db['vectors']:
        vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
        vectors.append(vector)
    # Records already use bytes as keys, so we can use them directly
    records = serializable_db['records']
    documents = serializable_db['documents']
    return Database(vectors, records, documents)


def save(db: Database, database_file: Path) -> None:
    """
    Saves the database to a file using pickle serialization.

    Args:
        db: The Database object to save
        database_file: Path where to save the database file
    """
    # Ensure the directory exists
    database_file.parent.mkdir(parents=True, exist_ok=True)
    # Create a serializable version of the database
    # Records already use bytes as keys, so we can use them directly
    serializable_db = {
        'vectors': [vector.tobytes() for vector in db.vectors],
        'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
        'vector_shape': db.vectors[0].shape if db.vectors else (),
        'records': db.records,  # Already uses bytes as keys
        'documents': db.documents,
    }
    # Save to file
    with open(database_file, 'wb') as f:
        pickle.dump(serializable_db, f)


def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]:
    """
    Query the database and return the N nearest records.

    Args:
        db: Database object or path to database file
        text: Query text to search for
        record_count: Number of nearest neighbors to return (default: 10)

    Returns:
        List of QueryResult objects, sorted by distance (closest first)
    """
    if isinstance(db, Path):
        db = load(db)
    # Generate embedding for query text
    query_vector = _embed(text)
    # Find nearest vectors
    # NOTE We're using Euclidean distance as the similarity metric;
    # there are alternatives (cosine or dot product) which may be used.
    # See https://en.wikipedia.org/wiki/Embedding_(machine_learning)
    nearest_results = _find_nearest(db.vectors, query_vector, record_count)
    # Convert results to QueryResult objects
    results: list[QueryResult] = []
    for distance, vector_idx in nearest_results:
        # Get the vector at this index
        vector = db.vectors[vector_idx]
        vector_bytes = vector.tobytes()
        # Look up the corresponding record
        if vector_bytes in db.records:
            record = db.records[vector_bytes]
            results.append(QueryResult(record, distance, db.documents[record.document_index].name))
    return results


def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
    """
    Adds a new document to the database. If a path is given, load, add, and save.
    Loads the PDF with PyMuPDF, splits it by pages, and creates records and vectors.
    Uses multithreading for embedding generation.

    Args:
        db: Database object or path to database file
        file: Path to PDF file to add
        max_workers: Maximum number of threads for parallel processing
    """
    save_to_file = False
    database_file_path = None
    if isinstance(db, Path):
        database_file_path = db
        db = load(db)
        save_to_file = True
    if not file.exists():
        raise FileNotFoundError(f"File not found: {file}")
    if file.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {file}")
    print(f"Processing PDF: {file}")
    document_index = len(db.documents)
    try:
        doc = pymupdf.open(file)
        print(f"PDF opened successfully: {len(doc)} pages")
        records: list[Record] = []
        chunk_size = 1024
        for page_num in range(len(doc)):
            page = doc[page_num]
            text = page.get_text().strip()
            if not text:
                print(f"  Page {page_num + 1}: Skipped (empty)")
                continue
            # Simple chunking - split text into chunks of the specified size
            for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
                chunk = text[i:i + chunk_size]
                if chunk_stripped := chunk.strip():  # Only add non-empty chunks
                    # page_num + 1 for user friendliness
                    records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx))
        doc.close()
    except Exception as e:
        raise RuntimeError(f"Error processing PDF {file}: {e}") from e
    # Process chunks in parallel
    print(f"Processing {len(records)} chunks with {max_workers} workers...")
    db.documents.append(file)
    # TODO measure with GIL disabled to check if multithreading actually helps
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(_vectorize_record, r) for r in records]
        for f in as_completed(futures):
            record, vector = f.result()
            db.records[vector.tobytes()] = record
            db.vectors.append(vector)
    print(f"Successfully processed {file}: {len(records)} chunks")
    # Save the database if we loaded it from a file
    if save_to_file and database_file_path:
        save(db, database_file_path)
        print(f"Database saved to {database_file_path}")


def get_document_path(db: Database | Path, document_index: int) -> Path:
    """
    Get the file path of the document at the given index in the database.

    Args:
        db: Database object or path to database file
        document_index: Index of the document to retrieve

    Returns:
        Path to the document file
    """
    if isinstance(db, Path):
        db = load(db)
    if document_index < 0 or document_index >= len(db.documents):
        raise IndexError(f"Document index out of range: {document_index}")
    return db.documents[document_index]
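
For context, a minimal usage sketch of the module above (not part of this commit; the file names and query string are hypothetical, and it assumes a local Ollama server with the nomic-embed-text model already pulled):

    from pathlib import Path
    import db

    database = db.create_empty()
    db.add_document(database, Path("manual.pdf"))  # hypothetical PDF
    db.save(database, Path("index.db"))            # hypothetical location

    for result in db.query(database, "how do I reset the device?", record_count=5):
        print(f"{result.document_name} p.{result.record.page} "
              f"(distance {result.distance:.4f}): {result.record.text[:80]}")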

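The Database docstring leaves faster nearest-neighbour lookup as a TODO. One possible direction, sketched here under the assumption that scipy is an acceptable extra dependency, is a k-d tree; _find_nearest_kdtree is a hypothetical drop-in for _find_nearest:

    from scipy.spatial import cKDTree  # assumption: scipy added as a dependency
    import numpy as np

    def _find_nearest_kdtree(vectors_db, query_vector, count=10):
        # Built per call for brevity; a real implementation would cache the
        # tree on the Database and rebuild only when vectors are added.
        tree = cKDTree(np.stack(vectors_db))
        distances, indices = tree.query(query_vector, k=min(count, len(vectors_db)))
        # k=1 returns scalars, so normalize to arrays
        distances, indices = np.atleast_1d(distances), np.atleast_1d(indices)
        return [(float(d), int(i)) for d, i in zip(distances, indices)]

Note that exact k-d trees degrade in high dimensions (nomic-embed-text vectors have 768), so an approximate index such as HNSW may be the more realistic replacement.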

@@ -1,75 +1,75 @@
<!DOCTYPE html>
<html>
<head>
    <title>Semantic Document Search</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
        .search-box { margin-bottom: 20px; }
        input[type="text"] { width: 70%; padding: 10px; font-size: 16px; }
        button { padding: 10px 20px; font-size: 16px; background: #007cba; color: white; border: none; cursor: pointer; }
        button:hover { background: #005c8a; }
        .result { border: 1px solid #ddd; margin: 10px 0; padding: 15px; border-radius: 5px; }
        .result-header { font-weight: bold; color: #333; margin-bottom: 10px; }
        .result-text { background: #f9f9f9; padding: 10px; border-radius: 3px; }
        .distance { color: #666; font-size: 0.9em; }
        .document-link { color: #007cba; text-decoration: none; }
        .document-link:hover { text-decoration: underline; }
        .no-results { text-align: center; color: #666; margin: 40px 0; }
        .loading { text-align: center; color: #007cba; margin: 20px 0; }
    </style>
</head>
<body>
    <h1>🔍 Semantic Document Search</h1>
    <div class="search-box">
        <form id="searchForm">
            <input type="text" id="queryInput" placeholder="Enter your search query..." required>
            <button type="submit">Search</button>
        </form>
    </div>
    <div id="results"></div>

    <script>
        document.getElementById('searchForm').addEventListener('submit', async (e) => {
            e.preventDefault();
            const query = document.getElementById('queryInput').value;
            const resultsDiv = document.getElementById('results');
            resultsDiv.innerHTML = '<div class="loading">Searching...</div>';
            try {
                const response = await fetch('/api/search', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ query: query })
                });
                const data = await response.json();
                if (data.error) {
                    resultsDiv.innerHTML = `<div class="no-results">Error: ${data.error}</div>`;
                    return;
                }
                if (data.results.length === 0) {
                    resultsDiv.innerHTML = '<div class="no-results">No results found.</div>';
                    return;
                }
                resultsDiv.innerHTML = data.results.map((result, i) => `
                    <div class="result">
                        <div class="result-header">
                            Result ${i + 1} - <a href="/file/${encodeURIComponent(result.document_index)}#page=${result.page}" class="document-link" target="_blank">${result.document_name}</a>
                            <span class="distance">(Distance: ${result.distance.toFixed(4)})</span>
                        </div>
                        <div>Page: ${result.page}, Chunk: ${result.chunk}</div>
                        <div class="result-text">${result.text}</div>
                    </div>
                `).join('');
            } catch (error) {
                resultsDiv.innerHTML = `<div class="no-results">Error: ${error.message}</div>`;
            }
        });
    </script>
</body>
</html>
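
The page assumes a backend exposing POST /api/search (returning JSON with either a results array or an error field) and GET /file/<document_index> for serving PDFs. That server is not part of this commit; a minimal Flask sketch of the expected contract, where Flask itself and the database location are assumptions, might look like:

    from pathlib import Path
    from flask import Flask, request, jsonify, send_file

    import db

    app = Flask(__name__)
    DATABASE_FILE = Path("index.db")  # hypothetical location

    @app.post("/api/search")
    def search():
        query_text = (request.get_json(silent=True) or {}).get("query", "").strip()
        if not query_text:
            return jsonify({"error": "empty query"})
        # Field names mirror what the page's JS reads from each result
        return jsonify({"results": [
            {
                "document_index": r.record.document_index,
                "document_name": r.document_name,
                "page": r.record.page,
                "chunk": r.record.chunk,
                "text": r.record.text,
                "distance": r.distance,
            }
            for r in db.query(DATABASE_FILE, query_text)
        ]})

    @app.get("/file/<int:document_index>")
    def file(document_index: int):
        return send_file(db.get_document_path(DATABASE_FILE, document_index))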