# Code-viewer metadata: 301 lines, 11 KiB, Python.
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
import tempfile
|
|
import numpy as np
|
|
from typing import Final
|
|
|
|
import db
|
|
|
|
# Database file the "create" command falls back to when no path is given.
DEFAULT_DB_PATH: Final[Path] = Path("db.pkl")
|
|
|
|
def _first_vector_mismatch(original_vectors, loaded_vectors):
    """Return the index of the first pair of vectors that differ, or None if every pair matches."""
    for index, (orig, loaded) in enumerate(zip(original_vectors, loaded_vectors)):
        if not np.array_equal(orig, loaded):
            return index
    return None


def _records_equal(original_records, loaded_records):
    """Return True when every key of original_records maps to an equal value in loaded_records.

    Prints a one-line diagnostic for the first problem found.
    """
    for key in original_records:
        if key not in loaded_records:
            print(" Missing record key!")
            return False
        if original_records[key] != loaded_records[key]:
            print(" Record content mismatch!")
            return False
    return True


def test_database():
    """Test database save/load functionality by creating, saving, loading and comparing.

    Steps, each reported on stdout:
      1. build a dummy database with ``db.create_dummy()``;
      2. save it to a temporary ``.pkl`` file with ``db.save``;
      3. load it back with ``db.load``;
      4. compare vector/record counts and contents of both databases;
      5. call ``db._embed`` once to check that the embedding backend answers.

    A final PASSED/FAILED line summarises all checks. Nothing is returned.
    The temporary file is deleted once it has been created, even if a later
    step raises.
    """
    print("=== Database Test ===")

    # Create dummy database
    print("1. Creating dummy database...")
    original_db = db.create_dummy()
    print(f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records")

    # Use len() rather than truthiness: it is already relied on above and also
    # works if `vectors` is a NumPy array, whose truth value is ambiguous.
    has_vectors = len(original_db.vectors) > 0

    # Print some details about the original database
    print(" First vector shape:", original_db.vectors[0].shape if has_vectors else "No vectors")
    print(" Sample vector:", original_db.vectors[0][:4] if has_vectors else "No vectors")
    print(" Sample record keys (first 3):", list(original_db.records.keys())[:3])

    # Create temporary file for testing. delete=False keeps the path alive after
    # the handle is closed so db.save/db.load can reopen it; it is removed in
    # the finally block below.
    with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file:
        test_file = Path(tmp_file.name)

    try:
        # Save database
        print(f"\n2. Saving database to {test_file}...")
        db.save(original_db, test_file)
        print(f" File size: {test_file.stat().st_size} bytes")

        # Load database
        print(f"\n3. Loading database from {test_file}...")
        loaded_db = db.load(test_file)
        print(f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records")

        # Compare databases
        print("\n4. Comparing original vs loaded...")

        # Check vector count
        vectors_match = len(original_db.vectors) == len(loaded_db.vectors)
        print(f" Vector count match: {vectors_match}")

        # Check record count
        records_match = len(original_db.records) == len(loaded_db.records)
        print(f" Record count match: {records_match}")

        # Check vector equality. Start from the count check so a count mismatch
        # is never reported as "all equal" when nothing was compared.
        vectors_equal = vectors_match
        if vectors_match:
            mismatch_index = _first_vector_mismatch(original_db.vectors, loaded_db.vectors)
            if mismatch_index is not None:
                vectors_equal = False
                print(f" Vector {mismatch_index} mismatch!")
        print(f" All vectors equal: {vectors_equal}")

        # Check record equality (same reasoning as for the vectors)
        records_equal = records_match and _records_equal(original_db.records, loaded_db.records)
        print(f" All records equal: {records_equal}")

        # Test embedding functionality
        print("\n5. Testing embedding functionality (Ollama API server)...")
        try:
            test_embedding = db._embed("This is a test text for embedding.")
            print(f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}")
            ollama_running = True
        except Exception as e:
            # Any failure here counts as "backend not reachable"; report it and
            # let the summary below mark the run as failed.
            print(f" Embedding test FAILED: {e}\n Did you start ollama docker image?")
            ollama_running = False

        # Summary
        all_good = vectors_match and records_match and vectors_equal and records_equal and ollama_running
        # The marker follows the verdict; previously a failure was still shown with a check mark.
        print(f"\n{'✅' if all_good else '❌'} Test {'PASSED' if all_good else 'FAILED'}")

    finally:
        # Clean up temporary file
        if test_file.exists():
            test_file.unlink()
            print(f" Cleaned up temporary file: {test_file}")
|
|
|
|
|
|
def create_database(db_path: str):
    """Create a new empty database and save it to ``db_path``.

    If the file already exists the user is asked to confirm the overwrite.
    Only an explicit "y"/"yes" (case-insensitive, surrounding whitespace
    ignored) proceeds; anything else, including an empty answer or a closed
    stdin, cancels without touching the file.

    Args:
        db_path: Path of the database file to create.
    """
    db_file = Path(db_path)

    # Check if file already exists
    if db_file.exists():
        try:
            response = input(f"Database {db_file} already exists. Overwrite? (y/N): ")
        except EOFError:
            # No interactive stdin (e.g. piped or closed): take the default "No"
            # instead of crashing with a traceback.
            print()
            response = ""
        if response.strip().lower() not in ("y", "yes"):
            print("Operation cancelled.")
            return

    # Create empty database
    empty_db = db.create_empty()

    # Save to file
    db.save(empty_db, db_file)

    print(f"✅ Created empty database: {db_file}")
    print(f" Vectors: {len(empty_db.vectors)}")
    print(f" Records: {len(empty_db.records)}")
|
|
|
|
|
|
def add_file(db_path: str, file_paths: list[str]):
    """Add each of the given files to the database and print a summary.

    Every path is handed to ``db.add_document`` in turn. A failure on one file
    is reported and recorded but does not stop the remaining files from being
    processed.

    Args:
        db_path: Path of the database file to add documents to.
        file_paths: Paths of the files to add.
    """
    total = len(file_paths)
    print(f"Adding {total} file(s) to database: {db_path}")

    target_db = Path(db_path)
    added = []
    errors = []

    for position, source in enumerate(file_paths, start=1):
        print(f"\n[{position}/{total}] Processing: {source}")
        try:
            db.add_document(target_db, Path(source))
            added.append(source)
            print(f"✅ Successfully added: {source}")
        except Exception as exc:
            # Best effort: remember the failure and carry on with the next file.
            errors.append((source, str(exc)))
            print(f"❌ Failed to add {source}: {exc}")

    # Summary
    divider = "=" * 60
    print(f"\n{divider}")
    print("SUMMARY:")
    print(f"✅ Successfully added: {len(added)} files")
    for source in added:
        print(f" - {Path(source).name}")

    if errors:
        print(f"❌ Failed to add: {len(errors)} files")
        for source, reason in errors:
            print(f" - {Path(source).name}: {reason}")

    print(divider)
|
|
|
|
|
|
def query(db_path: str, query_text: str):
    """Search the database for ``query_text`` and print the matches.

    For each (distance, record) pair returned by ``db.query`` this prints the
    distance, the document name, page and chunk, and a whitespace-normalised
    preview of the first 200 characters of the record text. Any exception
    raised along the way is reported on stdout instead of propagating.

    Args:
        db_path: Path of the database file to search.
        query_text: Text to search for.
    """
    print(f"Querying: '{query_text}' in database: {db_path}")

    try:
        matches = db.query(Path(db_path), query_text)

        if not matches:
            print("No results found.")
            return

        total = len(matches)
        print(f"\nFound {total} results:")
        print("=" * 60)

        for rank, (distance, record) in enumerate(matches, 1):
            print(f"\n{rank}. Distance: {distance:.4f}")
            print(f" Document: {record.document.name}")
            print(f" Page: {record.page}, Chunk: {record.chunk}")
            # Collapse every run of whitespace into a single space for cleaner display
            preview = ' '.join(record.text[:200].split())
            print(f" Text preview: {preview}...")
            # Separator between entries, but not after the last one
            if total > rank:
                print("-" * 40)

    except Exception as exc:
        print(f"Error querying database: {exc}")
|
|
|
|
|
|
|
|
def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
    """Start a web server for the semantic search tool.

    Serves ``index.html`` from the ``templates`` folder at ``/`` and a JSON
    search endpoint at ``POST /api/search``. The endpoint expects a JSON object
    with a string ``query`` field and answers with
    ``{"results": [{"distance", "document", "page", "chunk", "text"}, ...]}``.
    Malformed requests get HTTP 400; failures while searching get HTTP 500.

    Exits the process with status 1 if Flask is not installed or the database
    file does not exist.

    Args:
        db_path: Path of the database file to search.
        host: Address to bind to.
        port: Port to listen on.
    """
    # Imported lazily so the other commands work without Flask installed.
    try:
        from flask import Flask, request, jsonify, render_template
    except ImportError:
        print("❌ Flask not found. Please install it first:")
        print(" pip install flask")
        sys.exit(1)

    # Set template_folder to 'templates' directory
    app = Flask(__name__, template_folder="templates")
    db_file = Path(db_path)

    # Check if database exists
    if not db_file.exists():
        print(f"❌ Database file not found: {db_file}")
        print(" Create a database first using: python main.py create")
        sys.exit(1)

    @app.route('/')
    def index():
        return render_template("index.html", results=None)

    @app.route('/api/search', methods=['POST'])
    def search():
        # silent=True makes a malformed or non-JSON body come back as None
        # instead of raising, so client mistakes are answered with 400 below
        # rather than being reported as a server error (500).
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'query' not in data:
            return jsonify({'error': 'Missing query parameter'}), 400

        raw_query = data['query']
        if not isinstance(raw_query, str):
            return jsonify({'error': 'Query must be a string'}), 400

        query_text = raw_query.strip()
        if not query_text:
            return jsonify({'error': 'Query cannot be empty'}), 400

        try:
            # Perform the search
            results = db.query(db_file, query_text)

            # Format results for JSON response
            formatted_results = [
                {
                    'distance': float(distance),
                    'document': record.document.name,
                    'page': record.page,
                    'chunk': record.chunk,
                    # Clean and truncate text
                    'text': ' '.join(record.text[:300].split()),
                }
                for distance, record in results
            ]
            return jsonify({'results': formatted_results})

        except Exception as e:
            # Only failures of the search itself end up here.
            return jsonify({'error': str(e)}), 500

    print("🚀 Starting web server...")
    print(f" Database: {db_file}")
    print(f" URL: http://{host}:{port}")
    print(" Press Ctrl+C to stop")

    try:
        app.run(host=host, port=port, debug=False)
    except KeyboardInterrupt:
        print("\n👋 Web server stopped")
    except Exception as e:
        print(f"❌ Error starting web server: {e}")
|
|
|
|
|
|
def main():
    """Parse the command line and run the selected sub-command.

    Sub-commands (with their one-letter aliases): ``create``/``c``,
    ``add-file``/``a``, ``query``/``q``, ``host``/``h`` and ``test``/``t``.
    When no sub-command is given, the help text is printed and the process
    exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Semantic Search Tool",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    # One sub-parser per command
    commands = parser.add_subparsers(dest='command', help='Available commands')

    # create
    create_cmd = commands.add_parser('create', aliases=['c'], help='Create a new empty database')
    create_cmd.add_argument('db_path', nargs='?', default=str(DEFAULT_DB_PATH),
                            help=f'Path to database file (default: {DEFAULT_DB_PATH})')

    # add-file
    add_cmd = commands.add_parser('add-file', aliases=['a'], help='Add one or more files to the search database')
    add_cmd.add_argument('db', help='Path to the database file (e.g., db.pkl)')
    add_cmd.add_argument('file_paths', nargs='+', help='Path(s) to the PDF file(s) to add')

    # query
    query_cmd = commands.add_parser('query', aliases=['q'], help='Query the search database')
    query_cmd.add_argument('db', help='Path to the database file (e.g., db.pkl)')
    query_cmd.add_argument('query_text', help='Text to search for')

    # host (web server)
    host_cmd = commands.add_parser('host', aliases=['h'], help='Start a web server for semantic search')
    host_cmd.add_argument('db', help='Path to the database file (e.g., db.pkl)')
    host_cmd.add_argument('--host', default='127.0.0.1', help='Host address to bind to (default: 127.0.0.1)')
    host_cmd.add_argument('--port', type=int, default=5000, help='Port to listen on (default: 5000)')

    # test
    commands.add_parser('test', aliases=['t'], help='Test database save/load functionality')

    args = parser.parse_args()

    # Map every accepted spelling (full name and alias) to the action it triggers.
    # args.command holds whichever spelling was typed, or None when none was given.
    actions = {}
    for spellings, action in (
        (('create', 'c'), lambda: create_database(args.db_path)),
        (('add-file', 'a'), lambda: add_file(args.db, args.file_paths)),
        (('query', 'q'), lambda: query(args.db, args.query_text)),
        (('host', 'h'), lambda: start_web_server(args.db, args.host, args.port)),
        (('test', 't'), test_database),
    ):
        for spelling in spellings:
            actions[spelling] = action

    selected = actions.get(args.command)
    if selected is None:
        parser.print_help()
        sys.exit(1)

    selected()
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|