"""Command-line front end for the semantic search database.

Exposes four subcommands (``create``, ``add-file``, ``query`` and ``test``)
that delegate the actual work to the project-local ``db`` module.
"""
import argparse
import sys
import tempfile
from pathlib import Path
from typing import Final

import numpy as np

import db

DEFAULT_DB_PATH: Final[Path] = Path("db.pkl")

# Maximum number of characters of a record's text shown in query results.
_PREVIEW_CHARS: Final[int] = 200


def test_database() -> bool:
    """Round-trip a dummy database through save/load and probe the embedder.

    Creates a dummy database via ``db.create_dummy``, saves it to a temporary
    file, loads it back and compares vectors and records. Finally calls
    ``db._embed`` once to check that the embedding backend answers.

    Returns:
        True if every check passed, False otherwise. Exceptions raised by
        ``db.create_dummy``, ``db.save`` or ``db.load`` propagate to the
        caller (after the temporary file has been removed).
    """
    print("=== Database Test ===")

    # Create dummy database
    print("1. Creating dummy database...")
    original_db = db.create_dummy()
    print(f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records")

    # Print some details about the original database.
    # len() is used instead of truthiness so the check also works if
    # `vectors` is a NumPy array (whose truth value is ambiguous).
    has_vectors = len(original_db.vectors) > 0
    print(" First vector shape:", original_db.vectors[0].shape if has_vectors else "No vectors")
    print(" Sample vector:", original_db.vectors[0][:4] if has_vectors else "No vectors")
    print(" Sample record keys (first 3):", list(original_db.records.keys())[:3])

    # Reserve a temporary path. delete=False keeps the file after the handle
    # is closed so db.save/db.load can reopen it by name; it is removed in
    # the `finally` clause below.
    with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file:
        test_file = Path(tmp_file.name)

    try:
        # Save database
        print(f"\n2. Saving database to {test_file}...")
        db.save(original_db, test_file)
        print(f" File size: {test_file.stat().st_size} bytes")

        # Load database
        print(f"\n3. Loading database from {test_file}...")
        loaded_db = db.load(test_file)
        print(f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records")

        # Compare databases
        print("\n4. Comparing original vs loaded...")

        # Check vector count
        vectors_match = len(original_db.vectors) == len(loaded_db.vectors)
        print(f" Vector count match: {vectors_match}")

        # Check record count
        records_match = len(original_db.records) == len(loaded_db.records)
        print(f" Record count match: {records_match}")

        # Check vector equality. Start from the count check so a skipped
        # comparison is never reported as "all equal".
        vectors_equal = vectors_match
        if vectors_match:
            for i, (orig, loaded) in enumerate(zip(original_db.vectors, loaded_db.vectors)):
                if not np.array_equal(orig, loaded):
                    vectors_equal = False
                    print(f" Vector {i} mismatch!")
                    break
        print(f" All vectors equal: {vectors_equal}")

        # Check record equality (same reasoning as for the vectors).
        records_equal = records_match
        if records_match:
            for key in original_db.records:
                if key not in loaded_db.records:
                    records_equal = False
                    print(" Missing record key!")
                    break
                if original_db.records[key] != loaded_db.records[key]:
                    records_equal = False
                    print(" Record content mismatch!")
                    break
        print(f" All records equal: {records_equal}")

        # Test embedding functionality
        print("\n5. Testing embedding functionality (Ollama API server)...")
        try:
            # NOTE(review): reaches into db's non-public `_embed`; consider
            # exposing a public health-check helper in `db` instead.
            test_embedding = db._embed("This is a test text for embedding.")
            print(f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}")
            ollama_running = True
        except Exception as e:  # diagnostic boundary: report any failure
            print(f" Embedding test FAILED: {e}\n Did you start ollama docker image?")
            ollama_running = False

        # Summary
        all_good = (
            vectors_match
            and records_match
            and vectors_equal
            and records_equal
            and ollama_running
        )
        # Use a failure mark on failure; the original always printed a check mark.
        marker = '✅' if all_good else '❌'
        print(f"\n{marker} Test {'PASSED' if all_good else 'FAILED'}")
        return all_good
    finally:
        # Clean up temporary file
        if test_file.exists():
            test_file.unlink()
            print(f" Cleaned up temporary file: {test_file}")


def create_database(db_path: str) -> None:
    """Create a new empty database file at *db_path*.

    If the file already exists the user is asked for confirmation; anything
    other than ``y``/``yes`` (case-insensitive, surrounding whitespace
    ignored) or a closed stdin cancels the operation.

    Args:
        db_path: Destination path of the database file.
    """
    db_file = Path(db_path)

    # Check if file already exists
    if db_file.exists():
        try:
            response = input(f"Database {db_file} already exists. Overwrite? (y/N): ")
        except EOFError:
            # Non-interactive stdin: fall back to the safe default ("No").
            print()
            response = ""
        if response.strip().lower() not in ('y', 'yes'):
            print("Operation cancelled.")
            return

    # Create empty database
    empty_db = db.create_empty()

    # Save to file
    db.save(empty_db, db_file)
    print(f"✅ Created empty database: {db_file}")
    print(f" Vectors: {len(empty_db.vectors)}")
    print(f" Records: {len(empty_db.records)}")


def add_file(db_path: str, file_paths: list[str]) -> bool:
    """Add one or more files to the semantic search database.

    Each file is passed to ``db.add_document`` independently; a failure on
    one file is reported and does not stop processing of the remaining ones.

    Args:
        db_path: Path to the database file.
        file_paths: Paths of the files to add.

    Returns:
        True if every file was added, False if at least one failed.
    """
    print(f"Adding {len(file_paths)} file(s) to database: {db_path}")

    db_file = Path(db_path)
    successful_files = []
    failed_files = []

    for i, file_path in enumerate(file_paths, 1):
        print(f"\n[{i}/{len(file_paths)}] Processing: {file_path}")
        try:
            db.add_document(db_file, Path(file_path))
        except Exception as e:  # best effort: keep going with the next file
            failed_files.append((file_path, str(e)))
            print(f"❌ Failed to add {file_path}: {e}")
        else:
            successful_files.append(file_path)
            print(f"✅ Successfully added: {file_path}")

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY:")
    print(f"✅ Successfully added: {len(successful_files)} files")
    for file_path in successful_files:
        print(f" - {Path(file_path).name}")

    if failed_files:
        print(f"❌ Failed to add: {len(failed_files)} files")
        for file_path, error in failed_files:
            print(f" - {Path(file_path).name}: {error}")
    print(f"{'='*60}")

    return not failed_files


def query(db_path: str, query_text: str) -> bool:
    """Query the semantic search database and print the matches.

    Results come from ``db.query`` and are iterated as ``(distance, record)``
    pairs; for each one the distance, document name, page, chunk and a short
    whitespace-normalised text preview are printed.

    Args:
        db_path: Path to the database file.
        query_text: Text to search for.

    Returns:
        True if the query ran (even with zero results), False if an
        exception was raised while querying or printing.
    """
    print(f"Querying: '{query_text}' in database: {db_path}")

    try:
        results = db.query(Path(db_path), query_text)

        if not results:
            print("No results found.")
            return True

        print(f"\nFound {len(results)} results:")
        print("=" * 60)

        for i, (distance, record) in enumerate(results, 1):
            print(f"\n{i}. Distance: {distance:.4f}")
            print(f" Document: {record.document.name}")
            print(f" Page: {record.page}, Chunk: {record.chunk}")
            # Replace all whitespace characters with regular spaces for cleaner display
            clean_text = ' '.join(record.text[:_PREVIEW_CHARS].split())
            # Only show an ellipsis when the text was actually cut off.
            suffix = "..." if len(record.text) > _PREVIEW_CHARS else ""
            print(f" Text preview: {clean_text}{suffix}")
            if i < len(results):
                print("-" * 40)
    except Exception as e:  # CLI boundary: report instead of a traceback
        print(f"Error querying database: {e}")
        return False

    return True


def main() -> None:
    """Parse command-line arguments and dispatch to the chosen subcommand.

    Exits with status 1 when no command is given (after printing the help)
    or when the selected command reports a failure; otherwise returns
    normally.
    """
    parser = argparse.ArgumentParser(
        description="Semantic Search Tool",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    # Create subparsers for different commands
    subparsers = parser.add_subparsers(dest='command', help='Available commands')

    # Create command
    create_parser = subparsers.add_parser('create', aliases=['c'], help='Create a new empty database')
    create_parser.add_argument('db_path', nargs='?', default=str(DEFAULT_DB_PATH),
                               help=f'Path to database file (default: {DEFAULT_DB_PATH})')

    # Add file command
    add_parser = subparsers.add_parser('add-file', aliases=['a'],
                                       help='Add one or more files to the search database')
    add_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
    add_parser.add_argument('file_paths', nargs='+', help='Path(s) to the PDF file(s) to add')

    # Query command
    query_parser = subparsers.add_parser('query', aliases=['q'], help='Query the search database')
    query_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
    query_parser.add_argument('query_text', help='Text to search for')

    # Test command
    subparsers.add_parser('test', aliases=['t'], help='Test database save/load functionality')

    # Parse arguments
    args = parser.parse_args()

    # Handle commands. `args.command` holds whichever spelling (name or
    # alias) the user typed, hence the membership tests.
    if args.command in ('create', 'c'):
        create_database(args.db_path)
        succeeded = True
    elif args.command in ('add-file', 'a'):
        succeeded = add_file(args.db, args.file_paths)
    elif args.command in ('query', 'q'):
        succeeded = query(args.db, args.query_text)
    elif args.command in ('test', 't'):
        succeeded = test_database()
    else:
        parser.print_help()
        sys.exit(1)

    # Propagate command failure to the shell so scripts can react to it.
    if not succeeded:
        sys.exit(1)


if __name__ == "__main__":
    main()