Files
semantic_doc_search/main.py
2025-11-06 10:46:54 +01:00

360 lines
12 KiB
Python

import argparse
import sys
from pathlib import Path
import tempfile
import numpy as np
from typing import Final
import db
# Default database location; used by the `create` sub-command when no path is given.
DEFAULT_DB_PATH: Final[Path] = Path("db.pkl")
def _vectors_identical(original, loaded) -> bool:
    """Return True if both vector sequences have equal length and equal elements.

    Elements are compared pairwise with ``np.array_equal``; the index of the
    first differing vector is printed before returning False.
    """
    if len(original) != len(loaded):
        return False
    for index, (expected, actual) in enumerate(zip(original, loaded)):
        if not np.array_equal(expected, actual):
            print(f" Vector {index} mismatch!")
            return False
    return True


def _records_identical(original, loaded) -> bool:
    """Return True if both record mappings have the same size, keys and values.

    The first missing key or differing value is printed (with its key) before
    returning False.
    """
    if len(original) != len(loaded):
        return False
    for key in original:
        if key not in loaded:
            print(f" Missing record key: {key!r}")
            return False
        if original[key] != loaded[key]:
            print(f" Record content mismatch for key {key!r}!")
            return False
    return True


def _embedding_available() -> bool:
    """Make one embedding call through ``db._embed`` and report whether it worked."""
    try:
        test_embedding = db._embed("This is a test text for embedding.")
        print(
            f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}"
        )
        return True
    except Exception as e:  # any failure means the embedding backend is unusable
        print(
            f" Embedding test FAILED: {e}\n Did you start ollama docker image?"
        )
        return False


def test_database() -> bool:
    """Smoke-test the database layer: save/load round trip plus one embedding call.

    Builds a dummy database with ``db.create_dummy()``, saves it to a temporary
    ``.pkl`` file with ``db.save``, reloads it with ``db.load`` and compares the
    vector/record counts and contents. It then calls ``db._embed`` once to check
    that the embedding backend (Ollama) answers. Progress is printed to stdout,
    and the temporary file is always deleted.

    Returns:
        True if every check passed, otherwise False.
    """
    print("=== Database Test ===")

    print("1. Creating dummy database...")
    original_db = db.create_dummy()
    print(
        f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records"
    )
    # len() rather than truthiness so this also works if vectors is a NumPy array.
    has_vectors = len(original_db.vectors) > 0
    print(
        " First vector shape:",
        original_db.vectors[0].shape if has_vectors else "No vectors",
    )
    print(
        " Sample vector:",
        original_db.vectors[0][:4] if has_vectors else "No vectors",
    )
    print(" Sample record keys (first 3):", list(original_db.records.keys())[:3])

    # Reserve a unique file name, then close the handle (delete=False keeps the
    # file). db.save/db.load can then reopen it by name; Windows cannot reopen a
    # NamedTemporaryFile that is still open. The file is removed in `finally`.
    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
        test_file = Path(tmp_file.name)

    try:
        print(f"\n2. Saving database to {test_file}...")
        db.save(original_db, test_file)
        print(f" File size: {test_file.stat().st_size} bytes")

        print(f"\n3. Loading database from {test_file}...")
        loaded_db = db.load(test_file)
        print(
            f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records"
        )

        print("\n4. Comparing original vs loaded...")
        vectors_match = len(original_db.vectors) == len(loaded_db.vectors)
        print(f" Vector count match: {vectors_match}")
        records_match = len(original_db.records) == len(loaded_db.records)
        print(f" Record count match: {records_match}")
        # Content checks return False on a count mismatch instead of a vacuous True.
        vectors_equal = _vectors_identical(original_db.vectors, loaded_db.vectors)
        print(f" All vectors equal: {vectors_equal}")
        records_equal = _records_identical(original_db.records, loaded_db.records)
        print(f" All records equal: {records_equal}")

        print("\n5. Testing embedding functionality (Ollama API server)...")
        ollama_running = _embedding_available()

        all_good = all(
            (vectors_match, records_match, vectors_equal, records_equal, ollama_running)
        )
        print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}")
        return all_good
    finally:
        if test_file.exists():
            test_file.unlink()
            print(f" Cleaned up temporary file: {test_file}")
def create_database(db_path: str):
    """Create a new, empty database file at *db_path*.

    If the file already exists, the user is asked to confirm the overwrite.
    Only "y" or "yes" (case-insensitive, surrounding whitespace ignored)
    proceeds; anything else cancels. If stdin is closed (EOF), the prompt is
    treated as "No" instead of raising ``EOFError``.

    Args:
        db_path: Destination path for the database file.
    """
    db_file = Path(db_path)

    if db_file.exists():
        try:
            response = input(f"Database {db_file} already exists. Overwrite? (y/N): ")
        except EOFError:
            # Non-interactive run with no answer available: keep the default (N).
            print()
            response = ""
        if response.strip().lower() not in ("y", "yes"):
            print("Operation cancelled.")
            return

    empty_db = db.create_empty()
    db.save(empty_db, db_file)
    print(f"✅ Created empty database: {db_file}")
    print(f" Vectors: {len(empty_db.vectors)}")
    print(f" Records: {len(empty_db.records)}")
def add_file(db_path: str, file_paths: list[str]):
    """Index every path in *file_paths* into the database stored at *db_path*.

    Files are handled one by one. A failure on one file is recorded and
    reported without stopping the rest. A summary of added and failed files
    is printed at the end.
    """
    total = len(file_paths)
    print(f"Adding {total} file(s) to database: {db_path}")
    target_db = Path(db_path)
    added: list[str] = []
    errors: list[tuple[str, str]] = []

    for position, path_str in enumerate(file_paths, start=1):
        print(f"\n[{position}/{total}] Processing: {path_str}")
        try:
            db.add_document(target_db, Path(path_str))
            added.append(path_str)
            print(f"✅ Successfully added: {path_str}")
        except Exception as exc:  # best-effort batch: record the error, keep going
            errors.append((path_str, str(exc)))
            print(f"❌ Failed to add {path_str}: {exc}")

    divider = "=" * 60
    print(f"\n{divider}")
    print("SUMMARY:")
    print(f"✅ Successfully added: {len(added)} files")
    for path_str in added:
        print(f" - {Path(path_str).name}")
    if errors:
        print(f"❌ Failed to add: {len(errors)} files")
        for path_str, message in errors:
            print(f" - {Path(path_str).name}: {message}")
    print(divider)
def query(db_path: str, query_text: str):
    """Search the database at *db_path* for *query_text* and print the hits.

    For each result, prints its rank, distance, document name, page/chunk and
    a whitespace-collapsed preview of the first 200 characters of the chunk
    text. The "..." marker is added only when the preview was actually
    truncated. If ``db.query`` raises, the error is reported on stderr and no
    results are printed.
    """
    print(f"Querying: '{query_text}' in database: {db_path}")
    try:
        results = db.query(Path(db_path), query_text)
    except Exception as e:  # CLI boundary: report the failure instead of a traceback
        print(f"Error querying database: {e}", file=sys.stderr)
        return

    if not results:
        print("No results found.")
        return

    preview_limit = 200
    print(f"\nFound {len(results)} results:")
    print("=" * 60)
    for i, res in enumerate(results, 1):
        print(f"\n{i}. Distance: {res.distance:.4f}")
        print(f" Document: {res.document_name}")
        print(f" Page: {res.record.page}, Chunk: {res.record.chunk}")
        text = res.record.text
        # Collapse newlines/tabs/runs of whitespace so the preview stays on one line.
        preview = " ".join(text[:preview_limit].split())
        ellipsis = "..." if len(text) > preview_limit else ""
        print(f" Text preview: {preview}{ellipsis}")
        if i < len(results):
            print("-" * 40)
def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
    """Serve the semantic search web UI and JSON API with Flask.

    Routes:
        ``GET /`` renders ``templates/index.html``.
        ``GET /file/<document_index>`` sends the stored document file inline.
        ``POST /api/search`` takes a JSON body ``{"query": "<text>"}`` and
        returns ``{"results": [...]}``. Invalid input yields HTTP 400 and a
        failed search yields HTTP 500, both as ``{"error": "<message>"}``.

    Exits the process with status 1 if Flask is not installed or the database
    file does not exist. Otherwise runs the server until it is stopped.

    Args:
        db_path: Path to the database file to search.
        host: Interface address to bind to.
        port: TCP port to listen on.
    """
    try:
        from flask import Flask, request, jsonify, render_template, send_file
    except ImportError:
        print("❌ Flask not found. Please install it first:")
        print(" pip install flask")
        sys.exit(1)

    # Set template_folder to 'templates' directory
    app = Flask(__name__, template_folder="templates")
    db_file = Path(db_path)

    if not db_file.exists():
        print(f"❌ Database file not found: {db_file}")
        print(" Create a database first using: python main.py create")
        sys.exit(1)

    @app.route("/")
    def index():
        return render_template("index.html", results=None)

    @app.route("/file/<int:document_index>")
    def serve_file(document_index):
        """Serve a stored document (e.g. a PDF) inline by its database index."""
        try:
            # Resolve to an absolute path so send_file serves exactly the file
            # the existence check looked at. Flask may resolve relative paths
            # against the application root rather than the working directory.
            file_path = Path(db.get_document_path(db_file, document_index)).resolve()
            if not file_path.exists():
                return jsonify({"error": "File not found"}), 404
            return send_file(file_path, as_attachment=False)
        except Exception as e:
            return jsonify({"error": str(e)}), 500

    @app.route("/api/search", methods=["POST"])
    def search():
        # silent=True returns None for malformed or non-JSON bodies, so they are
        # rejected as a client error (400) instead of surfacing as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or "query" not in data:
            return jsonify({"error": "Missing query parameter"}), 400
        raw_query = data["query"]
        if not isinstance(raw_query, str):
            return jsonify({"error": "Query must be a string"}), 400
        query_text = raw_query.strip()
        if not query_text:
            return jsonify({"error": "Query cannot be empty"}), 400

        try:
            results = db.query(db_file, query_text)
            formatted_results = [
                {
                    "distance": float(res.distance),
                    "document_name": res.document_name,
                    "document_index": res.record.document_index,
                    "page": res.record.page,
                    "chunk": res.record.chunk,
                    # Collapse whitespace and truncate to 300 characters.
                    "text": " ".join(res.record.text[:300].split()),
                }
                for res in results
            ]
        except Exception as e:
            return jsonify({"error": str(e)}), 500
        return jsonify({"results": formatted_results})

    print("🚀 Starting web server...")
    print(f" Database: {db_file}")
    print(f" URL: http://{host}:{port}")
    print(" Press Ctrl+C to stop")
    try:
        app.run(host=host, port=port, debug=False)
    except KeyboardInterrupt:
        print("\n👋 Web server stopped")
    except Exception as e:
        print(f"❌ Error starting web server: {e}")
def main():
    """Command-line entry point: parse arguments and run the chosen sub-command.

    Sub-commands (with aliases) are ``create``/``c``, ``add-file``/``a``,
    ``query``/``q``, ``host``/``h`` and ``test``/``t``. Each sub-parser attaches
    its handler via ``set_defaults(run=...)``, so whichever alias was typed
    dispatches to the same function. Without a sub-command, the help text is
    printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Semantic Search Tool",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    commands = parser.add_subparsers(dest="command", help="Available commands")

    # create [db_path]
    create_cmd = commands.add_parser(
        "create", aliases=["c"], help="Create a new empty database"
    )
    create_cmd.add_argument(
        "db_path",
        nargs="?",
        default=str(DEFAULT_DB_PATH),
        help=f"Path to database file (default: {DEFAULT_DB_PATH})",
    )
    create_cmd.set_defaults(run=lambda ns: create_database(ns.db_path))

    # add-file <db> <file> [file ...]
    add_cmd = commands.add_parser(
        "add-file", aliases=["a"], help="Add one or more files to the search database"
    )
    add_cmd.add_argument("db", help="Path to the database file (e.g., db.pkl)")
    add_cmd.add_argument(
        "file_paths", nargs="+", help="Path(s) to the PDF file(s) to add"
    )
    add_cmd.set_defaults(run=lambda ns: add_file(ns.db, ns.file_paths))

    # query <db> <query_text>
    query_cmd = commands.add_parser(
        "query", aliases=["q"], help="Query the search database"
    )
    query_cmd.add_argument("db", help="Path to the database file (e.g., db.pkl)")
    query_cmd.add_argument("query_text", help="Text to search for")
    query_cmd.set_defaults(run=lambda ns: query(ns.db, ns.query_text))

    # host <db> [--host HOST] [--port PORT]
    host_cmd = commands.add_parser(
        "host", aliases=["h"], help="Start a web server for semantic search"
    )
    host_cmd.add_argument("db", help="Path to the database file (e.g., db.pkl)")
    host_cmd.add_argument(
        "--host",
        default="127.0.0.1",
        help="Host address to bind to (default: 127.0.0.1)",
    )
    host_cmd.add_argument(
        "--port", type=int, default=5000, help="Port to listen on (default: 5000)"
    )
    host_cmd.set_defaults(run=lambda ns: start_web_server(ns.db, ns.host, ns.port))

    # test
    test_cmd = commands.add_parser(
        "test", aliases=["t"], help="Test database save/load functionality"
    )
    test_cmd.set_defaults(run=lambda ns: test_database())

    args = parser.parse_args()
    run = getattr(args, "run", None)
    if run is None:
        # No sub-command given.
        parser.print_help()
        sys.exit(1)
    run(args)


if __name__ == "__main__":
    main()