Compare commits

..

4 Commits

Author SHA1 Message Date
Jan Mrna
ee8a8ad170 Fixed warnings from pylint 2025-11-06 10:58:18 +01:00
Jan Mrna
7010edae44 Format document 2025-11-06 10:46:54 +01:00
Jan Mrna
e734a13a59 Change line endings to LF 2025-11-06 10:46:31 +01:00
Jan Mrna
788eebc916 Serve by file index, not full path 2025-11-06 10:45:42 +01:00
3 changed files with 562 additions and 462 deletions

102
db.py
View File

@@ -1,3 +1,7 @@
#pylint: disable=missing-class-docstring,invalid-name,broad-exception-caught
"""
Database module for semantic document search tool.
"""
import pickle
from pathlib import Path
from dataclasses import dataclass
@@ -6,15 +10,16 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import pymupdf
import ollama # TODO split to another file
import ollama
#
# Types
#
type Vector = np.NDArray # np.NDArray[np.float32] ?
type Vector = np.NDArray
type VectorBytes = bytes
@dataclass(slots=True)
class Record:
document_index: int
@@ -22,11 +27,13 @@ class Record:
text: str
chunk: int = 0 # Chunk number within the page (0-indexed)
@dataclass(slots=True)
class QueryResult:
record: Record
distance: float
document: Path
document_name: str
@dataclass(slots=True)
class Database:
@@ -36,6 +43,7 @@ class Database:
TODO For faster nearest neighbour lookup we should use something else,
e.g. kd-trees
"""
vectors: list[Vector]
records: dict[VectorBytes, Record]
documents: list[Path]
@@ -46,7 +54,9 @@ class Database:
#
def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
def _find_nearest(
vectors_db: list[Vector], query_vector: Vector, count: int = 10
) -> list[tuple[float, int]]:
"""
Find the N nearest vectors to the query embedding.
@@ -71,6 +81,7 @@ def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 1
return results
def _embed(text: str) -> Vector:
"""
Generate embedding vector for given text.
@@ -82,11 +93,30 @@ def _embed(text: str) -> Vector:
def _vectorize_record(record: Record) -> tuple[Record, Vector]:
return record, _embed(record.text)
def test_embedding() -> bool:
"""
Test if embedding functionality is available and working.
Returns:
bool: True if embedding is working, False otherwise
"""
try:
_ = _embed("Test.")
return True
except Exception:
return False
#
# High-level (exported) functions
#
def create_dummy() -> Database:
"""
Create a dummy database for testing purposes.
"""
db_length: Final[int] = 10
vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)]
records = {
@@ -118,25 +148,26 @@ def load(database_file: Path) -> Database:
if not database_file.exists():
raise FileNotFoundError(f"Database file not found: {database_file}")
with open(database_file, 'rb') as f:
with open(database_file, "rb") as f:
serializable_db = pickle.load(f)
# Reconstruct vectors from bytes
vectors = []
vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64'))
vector_shape = serializable_db.get('vector_shape', ())
vector_dtype = np.dtype(serializable_db.get("vector_dtype", "float64"))
vector_shape = serializable_db.get("vector_shape", ())
for vector_bytes in serializable_db['vectors']:
for vector_bytes in serializable_db["vectors"]:
vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
vectors.append(vector)
# Records already use bytes as keys, so we can use them directly
records = serializable_db['records']
records = serializable_db["records"]
documents = serializable_db['documents']
documents = serializable_db["documents"]
return Database(vectors, records, documents)
def save(db: Database, database_file: Path) -> None:
"""
Saves the database to a file using pickle serialization.
@@ -151,15 +182,15 @@ def save(db: Database, database_file: Path) -> None:
# Create a serializable version of the database
# Records already use bytes as keys, so we can use them directly
serializable_db = {
'vectors': [vector.tobytes() for vector in db.vectors],
'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
'vector_shape': db.vectors[0].shape if db.vectors else (),
'records': db.records, # Already uses bytes as keys
'documents': db.documents,
"vectors": [vector.tobytes() for vector in db.vectors],
"vector_dtype": str(db.vectors[0].dtype) if db.vectors else "float64",
"vector_shape": db.vectors[0].shape if db.vectors else (),
"records": db.records, # Already uses bytes as keys
"documents": db.documents,
}
# Save to file
with open(database_file, 'wb') as f:
with open(database_file, "wb") as f:
pickle.dump(serializable_db, f)
@@ -197,10 +228,13 @@ def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryR
# Look up the corresponding record
if vector_bytes in db.records:
record = db.records[vector_bytes]
results.append(QueryResult(record, distance, db.documents[record.document_index]))
results.append(
QueryResult(record, distance, db.documents[record.document_index].name)
)
return results
def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
"""
Adds a new document to the database. If path is given, do load, add, save.
@@ -224,7 +258,7 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
if not file.exists():
raise FileNotFoundError(f"File not found: {file}")
if file.suffix.lower() != '.pdf':
if file.suffix.lower() != ".pdf":
raise ValueError(f"File must be a PDF: {file}")
print(f"Processing PDF: {file}")
@@ -237,29 +271,29 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
records: list[Record] = []
chunk_size = 1024
for page_num in range(len(doc)):
page = doc[page_num]
for page_num, page in enumerate(doc):
text = page.get_text().strip()
if not text:
print(f" Page {page_num + 1}: Skipped (empty)")
continue
# Simple chunking - split text into chunks of specified size
for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
chunk = text[i : i + chunk_size]
if chunk_stripped := chunk.strip(): # Only add non-empty chunks
# page_num + 1 for use friendliness
records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx))
records.append(
Record(document_index, page_num + 1, chunk_stripped, chunk_idx)
)
doc.close()
except Exception as e:
raise RuntimeError(f"Error processing PDF {file}: {e}")
raise RuntimeError(f"Error processing PDF {file}: {e}") from e
# Process chunks in parallel
print(f"Processing {len(records)} chunks with {max_workers} workers...")
db.documents.append(file)
# TODO measure with GIL disabled to check if multithreading actually helps
# NOTE this will only help with GIL disabled
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [pool.submit(_vectorize_record, r) for r in records]
for f in as_completed(futures):
@@ -273,3 +307,23 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
if save_to_file and database_file_path:
save(db, database_file_path)
print(f"Database saved to {database_file_path}")
def get_document_path(db: Database | Path, document_index: int) -> Path:
"""
Get the file path of the document at the given index in the database.
Args:
db: Database object or path to database file
document_index: Index of the document to retrieve
Returns:
Path to the document file
"""
if isinstance(db, Path):
db = load(db)
if document_index < 0 or document_index >= len(db.documents):
raise IndexError(f"Document index out of range: {document_index}")
return db.documents[document_index]

184
main.py
View File

@@ -1,14 +1,21 @@
import argparse
#pylint: disable=broad-exception-caught
"""
Semantic search tool main script.
Provides command-line interface and web server for creating, adding and querying the database.
"""
import sys
from pathlib import Path
import tempfile
import numpy as np
import argparse
from typing import Final
from pathlib import Path
import numpy as np
import db
DEFAULT_DB_PATH: Final[Path] = Path("db.pkl")
def test_database():
"""Test database save/load functionality by creating, saving, loading and comparing."""
print("=== Database Test ===")
@@ -16,15 +23,23 @@ def test_database():
# Create dummy database
print("1. Creating dummy database...")
original_db = db.create_dummy()
print(f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records")
print(
f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records"
)
# Print some details about the original database
print(" First vector shape:", original_db.vectors[0].shape if original_db.vectors else "No vectors")
print(" Sample vector:", original_db.vectors[0][:4] if original_db.vectors else "No vectors")
print(
" First vector shape:",
original_db.vectors[0].shape if original_db.vectors else "No vectors",
)
print(
" Sample vector:",
original_db.vectors[0][:4] if original_db.vectors else "No vectors",
)
print(" Sample record keys (first 3):", list(original_db.records.keys())[:3])
# Create temporary file for testing
with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file:
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
test_file = Path(tmp_file.name)
try:
@@ -36,7 +51,9 @@ def test_database():
# Load database
print(f"\n3. Loading database from {test_file}...")
loaded_db = db.load(test_file)
print(f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records")
print(
f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records"
)
# Compare databases
print("\n4. Comparing original vs loaded...")
@@ -52,7 +69,9 @@ def test_database():
# Check vector equality
vectors_equal = True
if vectors_match and original_db.vectors:
for i, (orig, loaded) in enumerate(zip(original_db.vectors, loaded_db.vectors)):
for i, (orig, loaded) in enumerate(
zip(original_db.vectors, loaded_db.vectors)
):
if not np.array_equal(orig, loaded):
vectors_equal = False
print(f" Vector {i} mismatch!")
@@ -75,16 +94,19 @@ def test_database():
# Test embedding functionality
print("\n5. Testing embedding functionality (Ollama API server)...")
try:
test_embedding = db._embed("This is a test text for embedding.")
print(f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}")
ollama_running = True
except Exception as e:
print(f" Embedding test FAILED: {e}\n Did you start ollama docker image?")
ollama_running = False
embedding_ok = db.test_embedding()
print(f" Embedding test {'PASSED' if embedding_ok else 'FAILED'}")
if not embedding_ok:
print(" Did you start ollama docker image?")
# Summary
all_good = vectors_match and records_match and vectors_equal and records_equal and ollama_running
all_good = (
vectors_match
and records_match
and vectors_equal
and records_equal
and embedding_ok
)
print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}")
finally:
@@ -101,7 +123,7 @@ def create_database(db_path: str):
# Check if file already exists
if db_file.exists():
response = input(f"Database {db_file} already exists. Overwrite? (y/N): ")
if response.lower() != 'y':
if response.lower() != "y":
print("Operation cancelled.")
return
@@ -166,10 +188,10 @@ def query(db_path: str, query_text: str):
for i, res in enumerate(results, 1):
print(f"\n{i}. Distance: {res.distance:.4f}")
print(f" Document: {res.document.name}")
print(f" Document: {res.document_name}")
print(f" Page: {res.record.page}, Chunk: {res.record.chunk}")
# Replace all whitespace characters with regular spaces for cleaner display
clean_text = ' '.join(res.record.text[:200].split())
clean_text = " ".join(res.record.text[:200].split())
print(f" Text preview: {clean_text}...")
if i < len(results):
print("-" * 40)
@@ -178,10 +200,11 @@ def query(db_path: str, query_text: str):
print(f"Error querying database: {e}")
def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
"""Start a web server for the semantic search tool."""
try:
# here we intentionally import inside the function to avoid Flask dependency for CLI usage
# pylint: disable=import-outside-toplevel
from flask import Flask, request, jsonify, render_template, send_file
except ImportError:
print("❌ Flask not found. Please install it first:")
@@ -197,36 +220,31 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
print(" Create a database first using: python main.py create")
sys.exit(1)
@app.route('/')
@app.route("/")
def index():
return render_template("index.html", results=None)
@app.route('/file/<path:document_path>')
def serve_file(document_path):
@app.route("/file/<int:document_index>")
def serve_file(document_index):
"""Serve PDF files directly."""
try:
file_path = Path(document_path)
file_path = db.get_document_path(db_file, document_index)
if not file_path.exists():
return jsonify({'error': 'File not found'}), 404
# Check if it's a PDF file for security
if file_path.suffix.lower() != '.pdf':
return jsonify({'error': 'Only PDF files are allowed'}), 403
return jsonify({"error": "File not found"}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': str(e)}), 500
return jsonify({"error": str(e)}), 500
@app.route('/api/search', methods=['POST'])
@app.route("/api/search", methods=["POST"])
def search():
try:
data = request.get_json()
if not data or 'query' not in data:
return jsonify({'error': 'Missing query parameter'}), 400
if not data or "query" not in data:
return jsonify({"error": "Missing query parameter"}), 400
query_text = data['query'].strip()
query_text = data["query"].strip()
if not query_text:
return jsonify({'error': 'Query cannot be empty'}), 400
return jsonify({"error": "Query cannot be empty"}), 400
# Perform the search
results = db.query(db_file, query_text)
@@ -234,19 +252,22 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
# Format results for JSON response
formatted_results = []
for res in results:
formatted_results.append({
'distance': float(res.distance),
'document': res.document.name,
'document_path': str(res.document), # Full path for the link
'page': res.record.page,
'chunk': res.record.chunk,
'text': ' '.join(res.record.text[:300].split()) # Clean and truncate text
})
return jsonify({'results': formatted_results})
formatted_results.append(
{
"distance": float(res.distance),
"document_name": res.document_name,
"document_index": res.record.document_index,
"page": res.record.page,
"chunk": res.record.chunk,
"text": " ".join(
res.record.text[:300].split()
), # Clean and truncate text
}
)
return jsonify({"results": formatted_results})
except Exception as e:
return jsonify({'error': str(e)}), 500
return jsonify({"error": str(e)}), 500
print("🚀 Starting web server...")
print(f" Database: {db_file}")
@@ -262,51 +283,76 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
def main():
"""
Main function to parse command-line arguments and execute commands.
"""
parser = argparse.ArgumentParser(
description="Semantic Search Tool",
formatter_class=argparse.RawDescriptionHelpFormatter
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# Create subparsers for different commands
subparsers = parser.add_subparsers(dest='command', help='Available commands')
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Create command
create_parser = subparsers.add_parser('create', aliases=['c'], help='Create a new empty database')
create_parser.add_argument('db_path', nargs='?', default=str(DEFAULT_DB_PATH),
help=f'Path to database file (default: {DEFAULT_DB_PATH})')
create_parser = subparsers.add_parser(
"create", aliases=["c"], help="Create a new empty database"
)
create_parser.add_argument(
"db_path",
nargs="?",
default=str(DEFAULT_DB_PATH),
help=f"Path to database file (default: {DEFAULT_DB_PATH})",
)
# Add file command
add_parser = subparsers.add_parser('add-file', aliases=['a'], help='Add one or more files to the search database')
add_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
add_parser.add_argument('file_paths', nargs='+', help='Path(s) to the PDF file(s) to add')
add_parser = subparsers.add_parser(
"add-file", aliases=["a"], help="Add one or more files to the search database"
)
add_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
add_parser.add_argument(
"file_paths", nargs="+", help="Path(s) to the PDF file(s) to add"
)
# Query command
query_parser = subparsers.add_parser('query', aliases=['q'], help='Query the search database')
query_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
query_parser.add_argument('query_text', help='Text to search for')
query_parser = subparsers.add_parser(
"query", aliases=["q"], help="Query the search database"
)
query_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
query_parser.add_argument("query_text", help="Text to search for")
# Host command (web server)
host_parser = subparsers.add_parser('host', aliases=['h'], help='Start a web server for semantic search')
host_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
host_parser.add_argument('--host', default='127.0.0.1', help='Host address to bind to (default: 127.0.0.1)')
host_parser.add_argument('--port', type=int, default=5000, help='Port to listen on (default: 5000)')
host_parser = subparsers.add_parser(
"host", aliases=["h"], help="Start a web server for semantic search"
)
host_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
host_parser.add_argument(
"--host",
default="127.0.0.1",
help="Host address to bind to (default: 127.0.0.1)",
)
host_parser.add_argument(
"--port", type=int, default=5000, help="Port to listen on (default: 5000)"
)
# Test command
subparsers.add_parser('test', aliases=['t'], help='Test database save/load functionality')
subparsers.add_parser(
"test", aliases=["t"], help="Test database save/load functionality"
)
# Parse arguments
args = parser.parse_args()
# Handle commands
if args.command in ['create', 'c']:
if args.command in ["create", "c"]:
create_database(args.db_path)
elif args.command in ['add-file', 'a']:
elif args.command in ["add-file", "a"]:
add_file(args.db, args.file_paths)
elif args.command in ['query', 'q']:
elif args.command in ["query", "q"]:
query(args.db, args.query_text)
elif args.command in ['host', 'h']:
elif args.command in ["host", "h"]:
start_web_server(args.db, args.host, args.port)
elif args.command in ['test', 't']:
elif args.command in ["test", "t"]:
test_database()
else:
parser.print_help()

View File

@@ -58,7 +58,7 @@
resultsDiv.innerHTML = data.results.map((result, i) => `
<div class="result">
<div class="result-header">
Result ${i + 1} - <a href="/file/${encodeURIComponent(result.document_path)}#page=${result.page}" class="document-link" target="_blank">${result.document}</a>
Result ${i + 1} - <a href="/file/${encodeURIComponent(result.document_index)}#page=${result.page}" class="document-link" target="_blank">${result.document_name}</a>
<span class="distance">(Distance: ${result.distance.toFixed(4)})</span>
</div>
<div>Page: ${result.page}, Chunk: ${result.chunk}</div>