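# Source snapshot: semantic_doc_search/main.py
# Captured 2025-11-06 10:58:18 +01:00 (364 lines, 12 KiB, Python).
# NOTE: the lines above were copied from a repository web view; they are kept
# as comments so the module stays importable.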
#pylint: disable=broad-exception-caught
"""
Semantic search tool main script.
Provides command-line interface and web server for creating, adding and querying the database.
"""
import sys
import tempfile
import argparse
from typing import Final
from pathlib import Path
import numpy as np
import db
# Database file used by the `create` sub-command when no path is given
# (also shown in that sub-command's --help text).
DEFAULT_DB_PATH: Final[Path] = Path("db.pkl")
def _vectors_equal(original_vectors, loaded_vectors) -> bool:
    """Return True if each pair of vectors from the two sequences is element-wise equal.

    Pairs are formed with zip(), so the caller must check that the vector
    counts match first. Prints the index of the first mismatching pair and
    stops there.
    """
    for index, (orig, loaded) in enumerate(zip(original_vectors, loaded_vectors)):
        if not np.array_equal(orig, loaded):
            print(f" Vector {index} mismatch!")
            return False
    return True


def _records_equal(original_records, loaded_records) -> bool:
    """Return True if every key of *original_records* is in *loaded_records* with an equal value.

    Only the original's keys are checked, so the caller must check that the
    record counts match first. Prints the first problem found and stops.
    """
    for key in original_records:
        if key not in loaded_records:
            print(" Missing record key!")
            return False
        if original_records[key] != loaded_records[key]:
            print(" Record content mismatch!")
            return False
    return True


def test_database():
    """Round-trip a dummy database through db.save/db.load and check embeddings.

    Steps:
      1. build a dummy database with db.create_dummy();
      2. save it to a temporary ``.pkl`` file with db.save();
      3. load it back with db.load();
      4. compare vector and record counts and contents;
      5. call db.test_embedding() (the printed hint suggests it needs a
         running Ollama server).

    Prints a PASSED/FAILED summary and always removes the temporary file.
    Returns None.
    """
    print("=== Database Test ===")
    print("1. Creating dummy database...")
    original_db = db.create_dummy()
    print(
        f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records"
    )
    # Use len() rather than plain truthiness: if `vectors` is a NumPy array,
    # `if array:` raises "truth value of an array is ambiguous".
    has_vectors = len(original_db.vectors) > 0
    print(
        " First vector shape:",
        original_db.vectors[0].shape if has_vectors else "No vectors",
    )
    print(
        " Sample vector:",
        original_db.vectors[0][:4] if has_vectors else "No vectors",
    )
    print(" Sample record keys (first 3):", list(original_db.records.keys())[:3])

    # delete=False plus closing the handle right away lets db.save reopen the
    # path by name; removal is handled by the finally clause below.
    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
        test_file = Path(tmp_file.name)
    try:
        print(f"\n2. Saving database to {test_file}...")
        db.save(original_db, test_file)
        print(f" File size: {test_file.stat().st_size} bytes")

        print(f"\n3. Loading database from {test_file}...")
        loaded_db = db.load(test_file)
        print(
            f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records"
        )

        print("\n4. Comparing original vs loaded...")
        vectors_match = len(original_db.vectors) == len(loaded_db.vectors)
        print(f" Vector count match: {vectors_match}")
        records_match = len(original_db.records) == len(loaded_db.records)
        print(f" Record count match: {records_match}")

        # Content is only comparable when the counts agree. A count mismatch
        # is reported as "not equal" instead of a vacuous True.
        vectors_equal = vectors_match and _vectors_equal(
            original_db.vectors, loaded_db.vectors
        )
        print(f" All vectors equal: {vectors_equal}")
        records_equal = records_match and _records_equal(
            original_db.records, loaded_db.records
        )
        print(f" All records equal: {records_equal}")

        print("\n5. Testing embedding functionality (Ollama API server)...")
        embedding_ok = db.test_embedding()
        print(f" Embedding test {'PASSED' if embedding_ok else 'FAILED'}")
        if not embedding_ok:
            print(" Did you start ollama docker image?")

        all_good = (
            vectors_match
            and records_match
            and vectors_equal
            and records_equal
            and embedding_ok
        )
        # Pick the icon from the outcome; a failed run previously still
        # showed a green check mark.
        status_icon = "✅" if all_good else "❌"
        print(f"\n{status_icon} Test {'PASSED' if all_good else 'FAILED'}")
    finally:
        if test_file.exists():
            test_file.unlink()
            print(f" Cleaned up temporary file: {test_file}")
def create_database(db_path: str):
    """Create a new, empty database file at *db_path*.

    If the file already exists, the user is asked to confirm the overwrite.
    Only "y" or "yes" confirms (case-insensitive, surrounding whitespace
    ignored). Any other answer, including end-of-input on stdin, cancels and
    leaves the existing file untouched.

    Args:
        db_path: Destination path of the database file.
    """
    db_file = Path(db_path)
    if db_file.exists():
        try:
            response = input(f"Database {db_file} already exists. Overwrite? (y/N): ")
        except EOFError:
            # stdin is closed or not interactive: treat this as "No" instead
            # of crashing, and end the dangling prompt line.
            print()
            response = ""
        if response.strip().lower() not in ("y", "yes"):
            print("Operation cancelled.")
            return

    empty_db = db.create_empty()
    db.save(empty_db, db_file)
    print(f"✅ Created empty database: {db_file}")
    print(f" Vectors: {len(empty_db.vectors)}")
    print(f" Records: {len(empty_db.records)}")
def add_file(db_path: str, file_paths: list[str]):
    """Index every path in *file_paths* into the database stored at *db_path*.

    Each file is passed to db.add_document() separately. An exception for one
    file is recorded and reported, then processing continues with the next
    file. The run ends with a summary of added and failed files, listed by
    base name; failures include their error message.
    """
    total = len(file_paths)
    print(f"Adding {total} file(s) to database: {db_path}")
    target_db = Path(db_path)
    added: list[str] = []
    failures: list[tuple[str, str]] = []

    for position, path_text in enumerate(file_paths, start=1):
        print(f"\n[{position}/{total}] Processing: {path_text}")
        try:
            db.add_document(target_db, Path(path_text))
            added.append(path_text)
            print(f"✅ Successfully added: {path_text}")
        except Exception as exc:
            failures.append((path_text, str(exc)))
            print(f"❌ Failed to add {path_text}: {exc}")

    rule = "=" * 60
    print(f"\n{rule}")
    print("SUMMARY:")
    print(f"✅ Successfully added: {len(added)} files")
    for path_text in added:
        print(f" - {Path(path_text).name}")
    if failures:
        print(f"❌ Failed to add: {len(failures)} files")
        for path_text, message in failures:
            print(f" - {Path(path_text).name}: {message}")
    print(rule)
def query(db_path: str, query_text: str):
    """Run *query_text* against the database at *db_path* and print the hits.

    For each result this prints:
      - the distance;
      - the document name;
      - the page and chunk position;
      - a whitespace-normalised preview of the chunk text (up to 200
        characters, followed by "..." only when the text was cut).

    Any exception raised while querying or printing is reported on stderr
    instead of being propagated.
    """
    print(f"Querying: '{query_text}' in database: {db_path}")
    try:
        results = db.query(Path(db_path), query_text)
        if not results:
            print("No results found.")
            return

        print(f"\nFound {len(results)} results:")
        print("=" * 60)
        for i, res in enumerate(results, 1):
            print(f"\n{i}. Distance: {res.distance:.4f}")
            print(f" Document: {res.document_name}")
            print(f" Page: {res.record.page}, Chunk: {res.record.chunk}")
            # Collapse newlines, tabs and runs of spaces BEFORE truncating, so
            # every preview shows up to 200 visible characters.
            clean_text = " ".join(res.record.text.split())
            preview = clean_text[:200].rstrip()
            # Signal truncation only when something was actually cut off.
            suffix = "..." if len(clean_text) > 200 else ""
            print(f" Text preview: {preview}{suffix}")
            if i < len(results):
                print("-" * 40)
    except Exception as e:
        print(f"Error querying database: {e}", file=sys.stderr)
def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
    """Serve the semantic search UI and JSON API with Flask.

    Routes:
        ``/``                       renders ``templates/index.html``.
        ``/file/<document_index>``  sends the file that db.get_document_path
                                    returns for that index.
        ``/api/search`` (POST)      expects a JSON object ``{"query": "<text>"}``
                                    and returns ``{"results": [...]}``.
                                    Invalid input gets a 400 response; lookup
                                    errors get a 500 response.

    Exits the process with status 1 if Flask is not installed, the database
    file does not exist, or the server fails to start.

    Args:
        db_path: Path to an existing database file.
        host: Interface address to bind to.
        port: TCP port to listen on.
    """
    try:
        # Imported lazily so the CLI commands work without Flask installed.
        # pylint: disable=import-outside-toplevel
        from flask import Flask, request, jsonify, render_template, send_file
    except ImportError:
        print("❌ Flask not found. Please install it first:")
        print(" pip install flask")
        sys.exit(1)

    # Templates are looked up in the 'templates' directory.
    app = Flask(__name__, template_folder="templates")
    db_file = Path(db_path)
    if not db_file.exists():
        print(f"❌ Database file not found: {db_file}")
        print(" Create a database first using: python main.py create")
        sys.exit(1)

    @app.route("/")
    def index():
        """Render the search page with no results yet."""
        return render_template("index.html", results=None)

    @app.route("/file/<int:document_index>")
    def serve_file(document_index):
        """Serve PDF files directly."""
        try:
            file_path = db.get_document_path(db_file, document_index)
            if not file_path.exists():
                return jsonify({"error": "File not found"}), 404
            return send_file(file_path, as_attachment=False)
        except Exception as e:
            return jsonify({"error": str(e)}), 500

    @app.route("/api/search", methods=["POST"])
    def search():
        """Run a semantic query from a JSON body and return the formatted hits."""
        try:
            # silent=True: a wrong Content-Type or malformed JSON yields None
            # (and so a 400 below) instead of an exception that became a 500.
            data = request.get_json(silent=True)
            if not isinstance(data, dict) or "query" not in data:
                return jsonify({"error": "Missing query parameter"}), 400
            raw_query = data["query"]
            if not isinstance(raw_query, str):
                return jsonify({"error": "Query must be a string"}), 400
            query_text = raw_query.strip()
            if not query_text:
                return jsonify({"error": "Query cannot be empty"}), 400

            results = db.query(db_file, query_text)
            formatted_results = [
                {
                    "distance": float(res.distance),
                    "document_name": res.document_name,
                    "document_index": res.record.document_index,
                    "page": res.record.page,
                    "chunk": res.record.chunk,
                    # Collapse whitespace before truncating, so the preview
                    # length does not depend on the chunk's raw layout.
                    "text": " ".join(res.record.text.split())[:300].rstrip(),
                }
                for res in results or []
            ]
            return jsonify({"results": formatted_results})
        except Exception as e:
            return jsonify({"error": str(e)}), 500

    print("🚀 Starting web server...")
    print(f" Database: {db_file}")
    print(f" URL: http://{host}:{port}")
    print(" Press Ctrl+C to stop")
    try:
        app.run(host=host, port=port, debug=False)
    except KeyboardInterrupt:
        print("\n👋 Web server stopped")
    except Exception as e:
        print(f"❌ Error starting web server: {e}")
        # Match the other start-up failures above: report failure to the shell.
        sys.exit(1)
def main():
    """Parse command-line arguments and run the selected sub-command.

    Sub-commands (alias in parentheses): create (c), add-file (a), query (q),
    host (h) and test (t). With no sub-command, prints help and exits with
    status 1.
    """
    parser = argparse.ArgumentParser(
        description="Semantic Search Tool",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Each sub-parser records its handler with set_defaults(). argparse applies
    # it for the canonical name and for every alias, so alias strings are not
    # repeated in a dispatch chain.

    # create
    create_parser = subparsers.add_parser(
        "create", aliases=["c"], help="Create a new empty database"
    )
    create_parser.add_argument(
        "db_path",
        nargs="?",
        default=str(DEFAULT_DB_PATH),
        help=f"Path to database file (default: {DEFAULT_DB_PATH})",
    )
    create_parser.set_defaults(handler=lambda a: create_database(a.db_path))

    # add-file
    add_parser = subparsers.add_parser(
        "add-file", aliases=["a"], help="Add one or more files to the search database"
    )
    add_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
    add_parser.add_argument(
        "file_paths", nargs="+", help="Path(s) to the PDF file(s) to add"
    )
    add_parser.set_defaults(handler=lambda a: add_file(a.db, a.file_paths))

    # query
    query_parser = subparsers.add_parser(
        "query", aliases=["q"], help="Query the search database"
    )
    query_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
    # nargs="+" accepts an unquoted multi-word query (`q db.pkl what is x`);
    # a single quoted argument still reaches query() unchanged.
    query_parser.add_argument("query_text", nargs="+", help="Text to search for")
    query_parser.set_defaults(
        handler=lambda a: query(a.db, " ".join(a.query_text))
    )

    # host (web server)
    host_parser = subparsers.add_parser(
        "host", aliases=["h"], help="Start a web server for semantic search"
    )
    host_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
    host_parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="Host address to bind to (default: 127.0.0.1)",
    )
    host_parser.add_argument(
        "--port", type=int, default=5000, help="Port to listen on (default: 5000)"
    )
    host_parser.set_defaults(
        handler=lambda a: start_web_server(a.db, a.host, a.port)
    )

    # test
    test_parser = subparsers.add_parser(
        "test", aliases=["t"], help="Test database save/load functionality"
    )
    test_parser.set_defaults(handler=lambda _a: test_database())

    args = parser.parse_args()
    handler = getattr(args, "handler", None)
    if handler is None:
        # No sub-command given.
        parser.print_help()
        sys.exit(1)
    handler(args)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()