Compare commits


4 Commits

Author     SHA1        Message                              Date
Jan Mrna   ee8a8ad170  Fixed warnings from pylint           2025-11-06 10:58:18 +01:00
Jan Mrna   7010edae44  Format document                      2025-11-06 10:46:54 +01:00
Jan Mrna   e734a13a59  Change line endings to LF            2025-11-06 10:46:31 +01:00
Jan Mrna   788eebc916  Serve by file index, not full path   2025-11-06 10:45:42 +01:00
3 changed files with 562 additions and 462 deletions

db.py (604 changed lines, post-commit version shown)

@@ -1,275 +1,329 @@
#pylint: disable=missing-class-docstring,invalid-name,broad-exception-caught
"""
Database module for semantic document search tool.
"""
import pickle
from pathlib import Path
from dataclasses import dataclass
from typing import Final
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pymupdf
import ollama

#
# Types
#

type Vector = np.NDArray
type VectorBytes = bytes


@dataclass(slots=True)
class Record:
    document_index: int
    page: int
    text: str
    chunk: int = 0  # Chunk number within the page (0-indexed)


@dataclass(slots=True)
class QueryResult:
    record: Record
    distance: float
    document_name: str


@dataclass(slots=True)
class Database:
    """
    "Vectors" hold the data for fast lookup of the vector,
    which can then be used to obtain record.
    TODO For faster nearest neighbour lookup we should use something else,
    e.g. kd-trees
    """

    vectors: list[Vector]
    records: dict[VectorBytes, Record]
    documents: list[Path]


#
# Internal functions
#


def _find_nearest(
    vectors_db: list[Vector], query_vector: Vector, count: int = 10
) -> list[tuple[float, int]]:
    """
    Find the N nearest vectors to the query embedding.

    Args:
        vectors_db: List of vectors in the database
        query_vector: Query embedding vector
        n: Number of nearest neighbors to return

    Returns:
        List of (distance, index) tuples, sorted by distance (closest first)
    """
    distances = [np.linalg.norm(x - query_vector) for x in vectors_db]

    # Get indices sorted by distance
    sorted_indices = np.argsort(distances)

    # Return top N results as (distance, index) tuples
    results = []
    for i in range(min(count, len(sorted_indices))):
        idx = sorted_indices[i]
        results.append((float(distances[idx]), int(idx)))

    return results


def _embed(text: str) -> Vector:
    """
    Generate embedding vector for given text.
    """
    MODEL: Final[str] = "nomic-embed-text"
    return np.array(ollama.embeddings(model=MODEL, prompt=text)["embedding"])


def _vectorize_record(record: Record) -> tuple[Record, Vector]:
    return record, _embed(record.text)


def test_embedding() -> bool:
    """
    Test if embedding functionality is available and working.

    Returns:
        bool: True if embedding is working, False otherwise
    """
    try:
        _ = _embed("Test.")
        return True
    except Exception:
        return False


#
# High-level (exported) functions
#


def create_dummy() -> Database:
    """
    Create a dummy database for testing purposes.
    """
    db_length: Final[int] = 10
    vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)]
    records = {
        vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
    }
    return Database(vectors, records, [Path("dummy")])


def create_empty() -> Database:
    """
    Creates a new empty database with no vectors or records.

    Returns:
        Empty Database object
    """
    return Database(vectors=[], records={}, documents=[])


def load(database_file: Path) -> Database:
    """
    Loads a database from the given file.

    Args:
        database_file: Path to the database file

    Returns:
        Database object loaded from file
    """
    if not database_file.exists():
        raise FileNotFoundError(f"Database file not found: {database_file}")

    with open(database_file, "rb") as f:
        serializable_db = pickle.load(f)

    # Reconstruct vectors from bytes
    vectors = []
    vector_dtype = np.dtype(serializable_db.get("vector_dtype", "float64"))
    vector_shape = serializable_db.get("vector_shape", ())

    for vector_bytes in serializable_db["vectors"]:
        vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
        vectors.append(vector)

    # Records already use bytes as keys, so we can use them directly
    records = serializable_db["records"]

    documents = serializable_db["documents"]

    return Database(vectors, records, documents)


def save(db: Database, database_file: Path) -> None:
    """
    Saves the database to a file using pickle serialization.

    Args:
        db: The Database object to save
        database_file: Path where to save the database file
    """
    # Ensure the directory exists
    database_file.parent.mkdir(parents=True, exist_ok=True)

    # Create a serializable version of the database
    # Records already use bytes as keys, so we can use them directly
    serializable_db = {
        "vectors": [vector.tobytes() for vector in db.vectors],
        "vector_dtype": str(db.vectors[0].dtype) if db.vectors else "float64",
        "vector_shape": db.vectors[0].shape if db.vectors else (),
        "records": db.records,  # Already uses bytes as keys
        "documents": db.documents,
    }

    # Save to file
    with open(database_file, "wb") as f:
        pickle.dump(serializable_db, f)


def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]:
    """
    Query the database and return the N nearest records.

    Args:
        db: Database object or path to database file
        text: Query text to search for
        record_count: Number of nearest neighbors to return (default: 10)

    Returns:
        List of (distance, Record) tuples, sorted by distance (closest first)
    """
    if isinstance(db, Path):
        db = load(db)

    # Generate embedding for query text
    query_vector = _embed(text)

    # Find nearest vectors
    # NOTE We're using euclidean distance as a metric of similarity,
    # there are some alternatives (cos or dot product) which may be used.
    # See https://en.wikipedia.org/wiki/Embedding_(machine_learning)
    nearest_results = _find_nearest(db.vectors, query_vector, record_count)

    # Convert results to (distance, Record) tuples
    results: list[QueryResult] = []
    for distance, vector_idx in nearest_results:
        # Get the vector at this index
        vector = db.vectors[vector_idx]
        vector_bytes = vector.tobytes()

        # Look up the corresponding record
        if vector_bytes in db.records:
            record = db.records[vector_bytes]
            results.append(
                QueryResult(record, distance, db.documents[record.document_index].name)
            )

    return results


def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
    """
    Adds a new document to the database. If path is given, do load, add, save.
    Loads PDF with PyMuPDF, splits by pages, and creates records and vectors.
    Uses multithreading for embedding generation.

    Args:
        db: Database object or path to database file
        file: Path to PDF file to add
        max_workers: Maximum number of threads for parallel processing
    """
    save_to_file = False
    database_file_path = None

    if isinstance(db, Path):
        database_file_path = db
        db = load(db)
        save_to_file = True

    if not file.exists():
        raise FileNotFoundError(f"File not found: {file}")

    if file.suffix.lower() != ".pdf":
        raise ValueError(f"File must be a PDF: {file}")

    print(f"Processing PDF: {file}")

    document_index = len(db.documents)
    try:
        doc = pymupdf.open(file)
        print(f"PDF opened successfully: {len(doc)} pages")

        records: list[Record] = []
        chunk_size = 1024

        for page_num, page in enumerate(doc):
            text = page.get_text().strip()
            if not text:
                print(f"  Page {page_num + 1}: Skipped (empty)")
                continue

            # Simple chunking - split text into chunks of specified size
            for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
                chunk = text[i : i + chunk_size]
                if chunk_stripped := chunk.strip():  # Only add non-empty chunks
                    # page_num + 1 for use friendliness
                    records.append(
                        Record(document_index, page_num + 1, chunk_stripped, chunk_idx)
                    )
        doc.close()
    except Exception as e:
        raise RuntimeError(f"Error processing PDF {file}: {e}") from e

    # Process chunks in parallel
    print(f"Processing {len(records)} chunks with {max_workers} workers...")

    db.documents.append(file)

    # NOTE this will only help with GIL disabled
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(_vectorize_record, r) for r in records]
        for f in as_completed(futures):
            record, vector = f.result()
            db.records[vector.tobytes()] = record
            db.vectors.append(vector)

    print(f"Successfully processed {file}: {len(records)} chunks")

    # Save database if we loaded it from file
    if save_to_file and database_file_path:
        save(db, database_file_path)
        print(f"Database saved to {database_file_path}")


def get_document_path(db: Database | Path, document_index: int) -> Path:
    """
    Get the file path of the document at the given index in the database.

    Args:
        db: Database object or path to database file
        document_index: Index of the document to retrieve

    Returns:
        Path to the document file
    """
    if isinstance(db, Path):
        db = load(db)
    if document_index < 0 or document_index >= len(db.documents):
        raise IndexError(f"Document index out of range: {document_index}")
    return db.documents[document_index]
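
Taken together, db.py now provides the whole pipeline: create or load a pickled database, chunk PDFs page by page with PyMuPDF, embed each chunk through a local Ollama server running the nomic-embed-text model, and answer queries by brute-force Euclidean nearest-neighbour search. A minimal usage sketch of this API follows; the PDF path and query string are made up for illustration, and a reachable Ollama server is assumed.

# Usage sketch for the db module (not part of the diff).
# Assumes a running Ollama server with the "nomic-embed-text" model pulled;
# "papers/example.pdf" and the query text are hypothetical.
from pathlib import Path

import db

database_file = Path("db.pkl")

db.save(db.create_empty(), database_file)  # start from an empty, persisted database
db.add_document(database_file, Path("papers/example.pdf"))  # load -> chunk -> embed -> save

for res in db.query(database_file, "nearest neighbour search", record_count=3):
    print(f"{res.distance:.4f}  {res.document_name}  page {res.record.page}")

Distances come from np.linalg.norm, i.e. Euclidean distance; as the NOTE in query() points out, cosine or dot-product similarity would be drop-in alternatives inside _find_nearest.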

main.py (272 changed lines, post-commit version shown)

@@ -1,64 +1,83 @@
#pylint: disable=broad-exception-caught
"""
Semantic search tool main script.
Provides command-line interface and web server for creating, adding and querying the database.
"""
import sys
import tempfile
import argparse
from typing import Final
from pathlib import Path

import numpy as np

import db

DEFAULT_DB_PATH: Final[Path] = Path("db.pkl")


def test_database():
    """Test database save/load functionality by creating, saving, loading and comparing."""
    print("=== Database Test ===")

    # Create dummy database
    print("1. Creating dummy database...")
    original_db = db.create_dummy()
    print(
        f"   Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records"
    )

    # Print some details about the original database
    print(
        "   First vector shape:",
        original_db.vectors[0].shape if original_db.vectors else "No vectors",
    )
    print(
        "   Sample vector:",
        original_db.vectors[0][:4] if original_db.vectors else "No vectors",
    )
    print("   Sample record keys (first 3):", list(original_db.records.keys())[:3])

    # Create temporary file for testing
    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
        test_file = Path(tmp_file.name)

    try:
        # Save database
        print(f"\n2. Saving database to {test_file}...")
        db.save(original_db, test_file)
        print(f"   File size: {test_file.stat().st_size} bytes")

        # Load database
        print(f"\n3. Loading database from {test_file}...")
        loaded_db = db.load(test_file)
        print(
            f"   Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records"
        )

        # Compare databases
        print("\n4. Comparing original vs loaded...")

        # Check vector count
        vectors_match = len(original_db.vectors) == len(loaded_db.vectors)
        print(f"   Vector count match: {vectors_match}")

        # Check record count
        records_match = len(original_db.records) == len(loaded_db.records)
        print(f"   Record count match: {records_match}")

        # Check vector equality
        vectors_equal = True
        if vectors_match and original_db.vectors:
            for i, (orig, loaded) in enumerate(
                zip(original_db.vectors, loaded_db.vectors)
            ):
                if not np.array_equal(orig, loaded):
                    vectors_equal = False
                    print(f"   Vector {i} mismatch!")
                    break
        print(f"   All vectors equal: {vectors_equal}")

        # Check record equality
        records_equal = True
        if records_match:
@@ -72,21 +91,24 @@ def test_database():
print(" Record content mismatch!") print(" Record content mismatch!")
break break
print(f" All records equal: {records_equal}") print(f" All records equal: {records_equal}")
# Test embedding functionality # Test embedding functionality
print("\n5. Testing embedding functionality (Ollama API server)...") print("\n5. Testing embedding functionality (Ollama API server)...")
try: embedding_ok = db.test_embedding()
test_embedding = db._embed("This is a test text for embedding.") print(f" Embedding test {'PASSED' if embedding_ok else 'FAILED'}")
print(f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}") if not embedding_ok:
ollama_running = True print(" Did you start ollama docker image?")
except Exception as e:
print(f" Embedding test FAILED: {e}\n Did you start ollama docker image?")
ollama_running = False
# Summary # Summary
all_good = vectors_match and records_match and vectors_equal and records_equal and ollama_running all_good = (
vectors_match
and records_match
and vectors_equal
and records_equal
and embedding_ok
)
print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}") print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}")
finally: finally:
# Clean up temporary file # Clean up temporary file
if test_file.exists(): if test_file.exists():
@@ -97,20 +119,20 @@ def test_database():
def create_database(db_path: str):
    """Create a new empty database."""
    db_file = Path(db_path)

    # Check if file already exists
    if db_file.exists():
        response = input(f"Database {db_file} already exists. Overwrite? (y/N): ")
        if response.lower() != "y":
            print("Operation cancelled.")
            return

    # Create empty database
    empty_db = db.create_empty()

    # Save to file
    db.save(empty_db, db_file)

    print(f"✅ Created empty database: {db_file}")
    print(f"   Vectors: {len(empty_db.vectors)}")
    print(f"   Records: {len(empty_db.records)}")
@@ -119,11 +141,11 @@ def create_database(db_path: str):
def add_file(db_path: str, file_paths: list[str]):
    """Add one or more files to the semantic search database."""
    print(f"Adding {len(file_paths)} file(s) to database: {db_path}")

    db_file = Path(db_path)
    successful_files = []
    failed_files = []

    for i, file_path in enumerate(file_paths, 1):
        print(f"\n[{i}/{len(file_paths)}] Processing: {file_path}")
        try:
@@ -133,7 +155,7 @@ def add_file(db_path: str, file_paths: list[str]):
        except Exception as e:
            failed_files.append((file_path, str(e)))
            print(f"❌ Failed to add {file_path}: {e}")

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY:")
@@ -141,47 +163,48 @@ def add_file(db_path: str, file_paths: list[str]):
    if successful_files:
        for file_path in successful_files:
            print(f"  - {Path(file_path).name}")

    if failed_files:
        print(f"❌ Failed to add: {len(failed_files)} files")
        for file_path, error in failed_files:
            print(f"  - {Path(file_path).name}: {error}")

    print(f"{'='*60}")


def query(db_path: str, query_text: str):
    """Query the semantic search database."""
    print(f"Querying: '{query_text}' in database: {db_path}")

    try:
        results = db.query(Path(db_path), query_text)

        if not results:
            print("No results found.")
            return

        print(f"\nFound {len(results)} results:")
        print("=" * 60)

        for i, res in enumerate(results, 1):
            print(f"\n{i}. Distance: {res.distance:.4f}")
            print(f"   Document: {res.document_name}")
            print(f"   Page: {res.record.page}, Chunk: {res.record.chunk}")
            # Replace all whitespace characters with regular spaces for cleaner display
            clean_text = " ".join(res.record.text[:200].split())
            print(f"   Text preview: {clean_text}...")

            if i < len(results):
                print("-" * 40)

    except Exception as e:
        print(f"Error querying database: {e}")


def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
    """Start a web server for the semantic search tool."""
    try:
        # here we intentionally import inside the function to avoid Flask dependency for CLI usage
        # pylint: disable=import-outside-toplevel
        from flask import Flask, request, jsonify, render_template, send_file
    except ImportError:
        print("❌ Flask not found. Please install it first:")
@@ -190,69 +213,67 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
    # Set template_folder to 'templates' directory
    app = Flask(__name__, template_folder="templates")

    db_file = Path(db_path)

    # Check if database exists
    if not db_file.exists():
        print(f"❌ Database file not found: {db_file}")
        print("   Create a database first using: python main.py create")
        sys.exit(1)

    @app.route("/")
    def index():
        return render_template("index.html", results=None)

    @app.route("/file/<int:document_index>")
    def serve_file(document_index):
        """Serve PDF files directly."""
        try:
            file_path = db.get_document_path(db_file, document_index)
            if not file_path.exists():
                return jsonify({"error": "File not found"}), 404

            return send_file(file_path, as_attachment=False)
        except Exception as e:
            return jsonify({"error": str(e)}), 500

    @app.route("/api/search", methods=["POST"])
    def search():
        try:
            data = request.get_json()
            if not data or "query" not in data:
                return jsonify({"error": "Missing query parameter"}), 400

            query_text = data["query"].strip()
            if not query_text:
                return jsonify({"error": "Query cannot be empty"}), 400

            # Perform the search
            results = db.query(db_file, query_text)

            # Format results for JSON response
            formatted_results = []
            for res in results:
                formatted_results.append(
                    {
                        "distance": float(res.distance),
                        "document_name": res.document_name,
                        "document_index": res.record.document_index,
                        "page": res.record.page,
                        "chunk": res.record.chunk,
                        "text": " ".join(
                            res.record.text[:300].split()
                        ),  # Clean and truncate text
                    }
                )

            return jsonify({"results": formatted_results})

        except Exception as e:
            return jsonify({"error": str(e)}), 500

    print("🚀 Starting web server...")
    print(f"   Database: {db_file}")
    print(f"   URL: http://{host}:{port}")
    print("   Press Ctrl+C to stop")

    try:
        app.run(host=host, port=port, debug=False)
    except KeyboardInterrupt:
@@ -262,51 +283,76 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
def main():
    """
    Main function to parse command-line arguments and execute commands.
    """
    parser = argparse.ArgumentParser(
        description="Semantic Search Tool",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Create subparsers for different commands
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Create command
    create_parser = subparsers.add_parser(
        "create", aliases=["c"], help="Create a new empty database"
    )
    create_parser.add_argument(
        "db_path",
        nargs="?",
        default=str(DEFAULT_DB_PATH),
        help=f"Path to database file (default: {DEFAULT_DB_PATH})",
    )

    # Add file command
    add_parser = subparsers.add_parser(
        "add-file", aliases=["a"], help="Add one or more files to the search database"
    )
    add_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
    add_parser.add_argument(
        "file_paths", nargs="+", help="Path(s) to the PDF file(s) to add"
    )

    # Query command
    query_parser = subparsers.add_parser(
        "query", aliases=["q"], help="Query the search database"
    )
    query_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
    query_parser.add_argument("query_text", help="Text to search for")

    # Host command (web server)
    host_parser = subparsers.add_parser(
        "host", aliases=["h"], help="Start a web server for semantic search"
    )
    host_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
    host_parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="Host address to bind to (default: 127.0.0.1)",
    )
    host_parser.add_argument(
        "--port", type=int, default=5000, help="Port to listen on (default: 5000)"
    )

    # Test command
    subparsers.add_parser(
        "test", aliases=["t"], help="Test database save/load functionality"
    )

    # Parse arguments
    args = parser.parse_args()

    # Handle commands
    if args.command in ["create", "c"]:
        create_database(args.db_path)
    elif args.command in ["add-file", "a"]:
        add_file(args.db, args.file_paths)
    elif args.command in ["query", "q"]:
        query(args.db, args.query_text)
    elif args.command in ["host", "h"]:
        start_web_server(args.db, args.host, args.port)
    elif args.command in ["test", "t"]:
        test_database()
    else:
        parser.print_help()
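
Outside the bundled page, the /api/search contract above is plain JSON: POST a body with a query field and the server answers with a results list whose entries carry distance, document_name, document_index, page, chunk and text. A stdlib-only client sketch against the defaults from main.py (the server is assumed to be running via python main.py host db.pkl, and the query text is illustrative):

# Hedged client sketch for POST /api/search; standard library only.
import json
from urllib import request

req = request.Request(
    "http://127.0.0.1:5000/api/search",
    data=json.dumps({"query": "vector database"}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with request.urlopen(req) as resp:
    payload = json.load(resp)

for hit in payload.get("results", []):
    # document_index is what the /file/<int:document_index> route expects
    print(hit["distance"], hit["document_name"], hit["page"], hit["chunk"])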

templates/index.html (post-commit version shown)

@@ -1,75 +1,75 @@
<!DOCTYPE html>
<html>
<head>
    <title>Semantic Document Search</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
        .search-box { margin-bottom: 20px; }
        input[type="text"] { width: 70%; padding: 10px; font-size: 16px; }
        button { padding: 10px 20px; font-size: 16px; background: #007cba; color: white; border: none; cursor: pointer; }
        button:hover { background: #005c8a; }
        .result { border: 1px solid #ddd; margin: 10px 0; padding: 15px; border-radius: 5px; }
        .result-header { font-weight: bold; color: #333; margin-bottom: 10px; }
        .result-text { background: #f9f9f9; padding: 10px; border-radius: 3px; }
        .distance { color: #666; font-size: 0.9em; }
        .document-link { color: #007cba; text-decoration: none; }
        .document-link:hover { text-decoration: underline; }
        .no-results { text-align: center; color: #666; margin: 40px 0; }
        .loading { text-align: center; color: #007cba; margin: 20px 0; }
    </style>
</head>
<body>
    <h1>🔍 Semantic Document Search</h1>

    <div class="search-box">
        <form id="searchForm">
            <input type="text" id="queryInput" placeholder="Enter your search query..." required>
            <button type="submit">Search</button>
        </form>
    </div>

    <div id="results"></div>

    <script>
        document.getElementById('searchForm').addEventListener('submit', async (e) => {
            e.preventDefault();
            const query = document.getElementById('queryInput').value;
            const resultsDiv = document.getElementById('results');

            resultsDiv.innerHTML = '<div class="loading">Searching...</div>';

            try {
                const response = await fetch('/api/search', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ query: query })
                });

                const data = await response.json();

                if (data.error) {
                    resultsDiv.innerHTML = `<div class="no-results">Error: ${data.error}</div>`;
                    return;
                }

                if (data.results.length === 0) {
                    resultsDiv.innerHTML = '<div class="no-results">No results found.</div>';
                    return;
                }

                resultsDiv.innerHTML = data.results.map((result, i) => `
                    <div class="result">
                        <div class="result-header">
                            Result ${i + 1} - <a href="/file/${encodeURIComponent(result.document_index)}#page=${result.page}" class="document-link" target="_blank">${result.document_name}</a>
                            <span class="distance">(Distance: ${result.distance.toFixed(4)})</span>
                        </div>
                        <div>Page: ${result.page}, Chunk: ${result.chunk}</div>
                        <div class="result-text">${result.text}</div>
                    </div>
                `).join('');
            } catch (error) {
                resultsDiv.innerHTML = `<div class="no-results">Error: ${error.message}</div>`;
            }
        });
    </script>
</body>
</html>
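
The result links now carry only an integer index, /file/<document_index>#page=<page>; the #page fragment is interpreted by the browser's PDF viewer and never reaches the server, which resolves the index through db.get_document_path instead of trusting a client-supplied path. A short sketch of fetching a served document outside the browser (index 0 and the output filename are illustrative):

# Sketch: download document 0 from a running server.
# An out-of-range index raises IndexError inside get_document_path,
# which the route reports as a JSON error with status 500.
from urllib import request

with request.urlopen("http://127.0.0.1:5000/file/0") as resp:
    data = resp.read()

with open("document_0.pdf", "wb") as f:
    f.write(data)
print(f"Fetched {len(data)} bytes")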