From e352780a3d53ca0bfc0c6adc4b4e6d565b0f322a Mon Sep 17 00:00:00 2001
From: Jan Mrna
Date: Thu, 6 Nov 2025 10:46:54 +0100
Subject: [PATCH] Format document

---
 db.py   | 136 ++++++++++++++++++--------
 main.py | 240 +++++++++++++++++++++++++++++++++-----------------------
 2 files changed, 220 insertions(+), 156 deletions(-)

diff --git a/db.py b/db.py
index a4eb306..a5961fe 100644
--- a/db.py
+++ b/db.py
@@ -6,15 +6,16 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import numpy as np
 import pymupdf
-import ollama # TODO split to another file
+import ollama  # TODO split to another file

 #
 # Types
 #

-type Vector = np.NDArray # np.NDArray[np.float32] ?
+type Vector = np.NDArray  # np.NDArray[np.float32] ?
 type VectorBytes = bytes
+

 @dataclass(slots=True)
 class Record:
     document_index: int
@@ -22,12 +23,14 @@ class Record:
     text: str
     chunk: int = 0  # Chunk number within the page (0-indexed)
+

 @dataclass(slots=True)
 class QueryResult:
     record: Record
     distance: float
     document_name: str
+

 @dataclass(slots=True)
 class Database:
     """
@@ -36,41 +39,45 @@ class Database:
     TODO For faster nearest neighbour lookup we should use something else, e.g. kd-trees
     """
+
     vectors: list[Vector]
     records: dict[VectorBytes, Record]
     documents: list[Path]
-

 #
 # Internal functions
 #

-def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
+def _find_nearest(
+    vectors_db: list[Vector], query_vector: Vector, count: int = 10
+) -> list[tuple[float, int]]:
     """
     Find the N nearest vectors to the query embedding.
-
+
     Args:
         vectors_db: List of vectors in the database
         query_vector: Query embedding vector
         count: Number of nearest neighbors to return
-
+
     Returns:
         List of (distance, index) tuples, sorted by distance (closest first)
     """
     distances = [np.linalg.norm(x - query_vector) for x in vectors_db]
-
+
     # Get indices sorted by distance
     sorted_indices = np.argsort(distances)
-
+
     # Return top N results as (distance, index) tuples
     results = []
     for i in range(min(count, len(sorted_indices))):
         idx = sorted_indices[i]
         results.append((float(distances[idx]), int(idx)))
-
+
     return results

+
 def _embed(text: str) -> Vector:
     """
     Generate embedding vector for given text.
@@ -82,13 +89,15 @@ def _embed(text: str) -> Vector:
 def _vectorize_record(record: Record) -> tuple[Record, Vector]:
     return record, _embed(record.text)
+

 #
 # High-level (exported) functions
 #

+
 def create_dummy() -> Database:
     db_length: Final[int] = 10
-    vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)]
+    vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)]
     records = {
         vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
     }
@@ -98,7 +107,7 @@ def create_empty() -> Database:
     """
     Creates a new empty database with no vectors or records.
-
+
     Returns:
         Empty Database object
     """
@@ -108,105 +117,109 @@ def load(database_file: Path) -> Database:
     """
     Loads a database from the given file.
-
+
     Args:
         database_file: Path to the database file
-
+
     Returns:
         Database object loaded from file
     """
     if not database_file.exists():
         raise FileNotFoundError(f"Database file not found: {database_file}")
-
-    with open(database_file, 'rb') as f:
+
+    with open(database_file, "rb") as f:
         serializable_db = pickle.load(f)
-
+
     # Reconstruct vectors from bytes
     vectors = []
-    vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64'))
-    vector_shape = serializable_db.get('vector_shape', ())
-
-    for vector_bytes in serializable_db['vectors']:
+    vector_dtype = np.dtype(serializable_db.get("vector_dtype", "float64"))
+    vector_shape = serializable_db.get("vector_shape", ())
+
+    for vector_bytes in serializable_db["vectors"]:
         vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
         vectors.append(vector)
-
-    # Records already use bytes as keys, so we can use them directly
-    records = serializable_db['records']
-    documents = serializable_db['documents']
-
+    # Records already use bytes as keys, so we can use them directly
+    records = serializable_db["records"]
+
+    documents = serializable_db["documents"]
+
     return Database(vectors, records, documents)

+
 def save(db: Database, database_file: Path) -> None:
     """
     Saves the database to a file using pickle serialization.
-
+
     Args:
         db: The Database object to save
         database_file: Path where the database file should be saved
     """
     # Ensure the directory exists
     database_file.parent.mkdir(parents=True, exist_ok=True)
-
+
     # Create a serializable version of the database
     # Records already use bytes as keys, so we can use them directly
     serializable_db = {
-        'vectors': [vector.tobytes() for vector in db.vectors],
-        'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
-        'vector_shape': db.vectors[0].shape if db.vectors else (),
-        'records': db.records, # Already uses bytes as keys
-        'documents': db.documents,
+        "vectors": [vector.tobytes() for vector in db.vectors],
+        "vector_dtype": str(db.vectors[0].dtype) if db.vectors else "float64",
+        "vector_shape": db.vectors[0].shape if db.vectors else (),
+        "records": db.records,  # Already uses bytes as keys
+        "documents": db.documents,
     }
-
+
     # Save to file
-    with open(database_file, 'wb') as f:
+    with open(database_file, "wb") as f:
         pickle.dump(serializable_db, f)


 def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]:
     """
     Query the database and return the N nearest records.
-
+
     Args:
         db: Database object or path to database file
         text: Query text to search for
         record_count: Number of nearest neighbors to return (default: 10)
-
+
     Returns:
         List of QueryResult objects, sorted by distance (closest first)
     """
     if isinstance(db, Path):
         db = load(db)
-
+
     # Generate embedding for query text
     query_vector = _embed(text)
-
+
     # Find nearest vectors
     # NOTE We're using euclidean distance as a metric of similarity,
     # there are some alternatives (cos or dot product) which may be used.
     # See https://en.wikipedia.org/wiki/Embedding_(machine_learning)
     nearest_results = _find_nearest(db.vectors, query_vector, record_count)
-
+
     # Convert results to QueryResult objects
     results: list[QueryResult] = []
     for distance, vector_idx in nearest_results:
         # Get the vector at this index
         vector = db.vectors[vector_idx]
         vector_bytes = vector.tobytes()
-
+
         # Look up the corresponding record
         if vector_bytes in db.records:
             record = db.records[vector_bytes]
-            results.append(QueryResult(record, distance, db.documents[record.document_index].name))
-
+            results.append(
+                QueryResult(record, distance, db.documents[record.document_index].name)
+            )
+
     return results

+
 def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
     """
     Adds a new document to the database. If path is given, do load, add, save.
     Loads PDF with PyMuPDF, splits by pages, and creates records and vectors.
     Uses multithreading for embedding generation.
-
+
     Args:
         db: Database object or path to database file
         file: Path to PDF file to add
@@ -215,25 +228,25 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
     """
     save_to_file = False
     database_file_path = None
-
+
     if isinstance(db, Path):
         database_file_path = db
         db = load(db)
         save_to_file = True
-
+
     if not file.exists():
         raise FileNotFoundError(f"File not found: {file}")
-
-    if file.suffix.lower() != '.pdf':
+
+    if file.suffix.lower() != ".pdf":
         raise ValueError(f"File must be a PDF: {file}")
-
+
     print(f"Processing PDF: {file}")
-
+
     document_index = len(db.documents)
     try:
         doc = pymupdf.open(file)
         print(f"PDF opened successfully: {len(doc)} pages")
-
+
         records: list[Record] = []
         chunk_size = 1024
@@ -243,13 +256,15 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
             if not text:
                 print(f" Page {page_num + 1}: Skipped (empty)")
                 continue
-
+
             # Simple chunking - split text into chunks of specified size
             for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
-                chunk = text[i:i + chunk_size]
+                chunk = text[i : i + chunk_size]
                 if chunk_stripped := chunk.strip():  # Only add non-empty chunks
                     # page_num + 1 for user friendliness
-                    records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx))
+                    records.append(
+                        Record(document_index, page_num + 1, chunk_stripped, chunk_idx)
+                    )
         doc.close()
     except Exception as e:
         raise RuntimeError(f"Error processing PDF {file}: {e}")
@@ -261,34 +276,35 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
     # TODO measure with GIL disabled to check if multithreading actually helps
     with ThreadPoolExecutor(max_workers=max_workers) as pool:
-        futures = [ pool.submit(_vectorize_record, r) for r in records ]
+        futures = [pool.submit(_vectorize_record, r) for r in records]
         for f in as_completed(futures):
             record, vector = f.result()
             db.records[vector.tobytes()] = record
             db.vectors.append(vector)

     print(f"Successfully processed {file}: {len(records)} chunks")
-
+
     # Save database if we loaded it from file
     if save_to_file and database_file_path:
         save(db, database_file_path)
         print(f"Database saved to {database_file_path}")

+
 def get_document_path(db: Database | Path, document_index: int) -> Path:
     """
     Get the file path of the document at the given index in the database.
-
+
     Args:
         db: Database object or path to database file
         document_index: Index of the document to retrieve
-
+
     Returns:
         Path to the document file
     """
     if isinstance(db, Path):
         db = load(db)
-
+
     if document_index < 0 or document_index >= len(db.documents):
         raise IndexError(f"Document index out of range: {document_index}")
-
-    return db.documents[document_index]
\ No newline at end of file
+
+    return db.documents[document_index]
diff --git a/main.py b/main.py
index 9365d59..1ecd9ea 100644
--- a/main.py
+++ b/main.py
@@ -9,56 +9,69 @@
 import db

 DEFAULT_DB_PATH: Final[Path] = Path("db.pkl")

+
 def test_database():
     """Test database save/load functionality by creating, saving, loading and comparing."""
     print("=== Database Test ===")
-
+
     # Create dummy database
     print("1. Creating dummy database...")
     original_db = db.create_dummy()
-    print(f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records")
-
+    print(
+        f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records"
+    )
+
     # Print some details about the original database
-    print(" First vector shape:", original_db.vectors[0].shape if original_db.vectors else "No vectors")
-    print(" Sample vector:", original_db.vectors[0][:4] if original_db.vectors else "No vectors")
+    print(
+        " First vector shape:",
+        original_db.vectors[0].shape if original_db.vectors else "No vectors",
+    )
+    print(
+        " Sample vector:",
+        original_db.vectors[0][:4] if original_db.vectors else "No vectors",
+    )
     print(" Sample record keys (first 3):", list(original_db.records.keys())[:3])
-
+
     # Create temporary file for testing
-    with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file:
+    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
         test_file = Path(tmp_file.name)
-
+
     try:
         # Save database
         print(f"\n2. Saving database to {test_file}...")
         db.save(original_db, test_file)
         print(f" File size: {test_file.stat().st_size} bytes")
-
+
         # Load database
         print(f"\n3. Loading database from {test_file}...")
         loaded_db = db.load(test_file)
-        print(f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records")
-
+        print(
+            f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records"
+        )
+
         # Compare databases
         print("\n4. Comparing original vs loaded...")
-
+
         # Check vector count
         vectors_match = len(original_db.vectors) == len(loaded_db.vectors)
         print(f" Vector count match: {vectors_match}")
-
+
         # Check record count
         records_match = len(original_db.records) == len(loaded_db.records)
         print(f" Record count match: {records_match}")
-
+
         # Check vector equality
         vectors_equal = True
         if vectors_match and original_db.vectors:
-            for i, (orig, loaded) in enumerate(zip(original_db.vectors, loaded_db.vectors)):
+            for i, (orig, loaded) in enumerate(
+                zip(original_db.vectors, loaded_db.vectors)
+            ):
                 if not np.array_equal(orig, loaded):
                     vectors_equal = False
                     print(f" Vector {i} mismatch!")
                     break
         print(f" All vectors equal: {vectors_equal}")
-
+
         # Check record equality
         records_equal = True
         if records_match:
@@ -72,21 +85,31 @@ def test_database():
                     print(" Record content mismatch!")
                     break
             print(f" All records equal: {records_equal}")
-
+
         # Test embedding functionality
         print("\n5. Testing embedding functionality (Ollama API server)...")
         try:
             test_embedding = db._embed("This is a test text for embedding.")
-            print(f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}")
+            print(
+                f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}"
+            )
             ollama_running = True
         except Exception as e:
-            print(f" Embedding test FAILED: {e}\n Did you start the ollama docker image?")
+            print(
+                f" Embedding test FAILED: {e}\n Did you start the ollama docker image?"
+            )
             ollama_running = False

         # Summary
-        all_good = vectors_match and records_match and vectors_equal and records_equal and ollama_running
+        all_good = (
+            vectors_match
+            and records_match
+            and vectors_equal
+            and records_equal
+            and ollama_running
+        )
         print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}")
-
+
     finally:
         # Clean up temporary file
         if test_file.exists():
@@ -97,20 +120,20 @@ def create_database(db_path: str):
     """Create a new empty database."""
     db_file = Path(db_path)
-
+
     # Check if file already exists
     if db_file.exists():
         response = input(f"Database {db_file} already exists. Overwrite? (y/N): ")
-        if response.lower() != 'y':
+        if response.lower() != "y":
             print("Operation cancelled.")
             return
-
+
     # Create empty database
     empty_db = db.create_empty()
-
+
     # Save to file
     db.save(empty_db, db_file)
-
+
     print(f"✅ Created empty database: {db_file}")
     print(f" Vectors: {len(empty_db.vectors)}")
     print(f" Records: {len(empty_db.records)}")
@@ -119,11 +142,11 @@ def add_file(db_path: str, file_paths: list[str]):
     """Add one or more files to the semantic search database."""
     print(f"Adding {len(file_paths)} file(s) to database: {db_path}")
-
+
     db_file = Path(db_path)
     successful_files = []
     failed_files = []
-
+
     for i, file_path in enumerate(file_paths, 1):
         print(f"\n[{i}/{len(file_paths)}] Processing: {file_path}")
         try:
@@ -133,7 +156,7 @@ def add_file(db_path: str, file_paths: list[str]):
         except Exception as e:
             failed_files.append((file_path, str(e)))
             print(f"❌ Failed to add {file_path}: {e}")
-
+
     # Summary
     print(f"\n{'='*60}")
     print("SUMMARY:")
@@ -141,44 +164,43 @@ def add_file(db_path: str, file_paths: list[str]):
     if successful_files:
         for file_path in successful_files:
             print(f" - {Path(file_path).name}")
-
+
     if failed_files:
         print(f"❌ Failed to add: {len(failed_files)} files")
         for file_path, error in failed_files:
             print(f" - {Path(file_path).name}: {error}")
-
+
     print(f"{'='*60}")


 def query(db_path: str, query_text: str):
     """Query the semantic search database."""
     print(f"Querying: '{query_text}' in database: {db_path}")
-
+
     try:
         results = db.query(Path(db_path), query_text)
-
+
         if not results:
             print("No results found.")
             return
-
+
         print(f"\nFound {len(results)} results:")
         print("=" * 60)
-
+
         for i, res in enumerate(results, 1):
Distance: {res.distance:.4f}") print(f" Document: {res.document_name}") print(f" Page: {res.record.page}, Chunk: {res.record.chunk}") # Replace all whitespace characters with regular spaces for cleaner display - clean_text = ' '.join(res.record.text[:200].split()) + clean_text = " ".join(res.record.text[:200].split()) print(f" Text preview: {clean_text}...") if i < len(results): print("-" * 40) - + except Exception as e: print(f"Error querying database: {e}") - def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000): """Start a web server for the semantic search tool.""" try: @@ -190,63 +212,67 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000): # Set template_folder to 'templates' directory app = Flask(__name__, template_folder="templates") db_file = Path(db_path) - + # Check if database exists if not db_file.exists(): print(f"❌ Database file not found: {db_file}") print(" Create a database first using: python main.py create") sys.exit(1) - - @app.route('/') + + @app.route("/") def index(): return render_template("index.html", results=None) - - @app.route('/file/') + + @app.route("/file/") def serve_file(document_index): """Serve PDF files directly.""" try: file_path = db.get_document_path(db_file, document_index) if not file_path.exists(): - return jsonify({'error': 'File not found'}), 404 + return jsonify({"error": "File not found"}), 404 return send_file(file_path, as_attachment=False) except Exception as e: - return jsonify({'error': str(e)}), 500 - - @app.route('/api/search', methods=['POST']) + return jsonify({"error": str(e)}), 500 + + @app.route("/api/search", methods=["POST"]) def search(): try: data = request.get_json() - if not data or 'query' not in data: - return jsonify({'error': 'Missing query parameter'}), 400 - - query_text = data['query'].strip() + if not data or "query" not in data: + return jsonify({"error": "Missing query parameter"}), 400 + + query_text = data["query"].strip() if not query_text: - return jsonify({'error': 'Query cannot be empty'}), 400 - + return jsonify({"error": "Query cannot be empty"}), 400 + # Perform the search results = db.query(db_file, query_text) - + # Format results for JSON response formatted_results = [] for res in results: - formatted_results.append({ - 'distance': float(res.distance), - 'document_name': res.document_name, - 'document_index': res.record.document_index, - 'page': res.record.page, - 'chunk': res.record.chunk, - 'text': ' '.join(res.record.text[:300].split()) # Clean and truncate text - }) - return jsonify({'results': formatted_results}) - + formatted_results.append( + { + "distance": float(res.distance), + "document_name": res.document_name, + "document_index": res.record.document_index, + "page": res.record.page, + "chunk": res.record.chunk, + "text": " ".join( + res.record.text[:300].split() + ), # Clean and truncate text + } + ) + return jsonify({"results": formatted_results}) + except Exception as e: - return jsonify({'error': str(e)}), 500 - + return jsonify({"error": str(e)}), 500 + print("🚀 Starting web server...") print(f" Database: {db_file}") print(f" URL: http://{host}:{port}") print(" Press Ctrl+C to stop") - + try: app.run(host=host, port=port, debug=False) except KeyboardInterrupt: @@ -258,49 +284,71 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000): def main(): parser = argparse.ArgumentParser( description="Semantic Search Tool", - formatter_class=argparse.RawDescriptionHelpFormatter + 
+        formatter_class=argparse.RawDescriptionHelpFormatter,
     )
-
+
     # Create subparsers for different commands
-    subparsers = parser.add_subparsers(dest='command', help='Available commands')
-
+    subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
     # Create command
-    create_parser = subparsers.add_parser('create', aliases=['c'], help='Create a new empty database')
-    create_parser.add_argument('db_path', nargs='?', default=str(DEFAULT_DB_PATH),
-                               help=f'Path to database file (default: {DEFAULT_DB_PATH})')
-
+    create_parser = subparsers.add_parser(
+        "create", aliases=["c"], help="Create a new empty database"
+    )
+    create_parser.add_argument(
+        "db_path",
+        nargs="?",
+        default=str(DEFAULT_DB_PATH),
+        help=f"Path to database file (default: {DEFAULT_DB_PATH})",
+    )
+
     # Add file command
-    add_parser = subparsers.add_parser('add-file', aliases=['a'], help='Add one or more files to the search database')
-    add_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
-    add_parser.add_argument('file_paths', nargs='+', help='Path(s) to the PDF file(s) to add')
-
+    add_parser = subparsers.add_parser(
+        "add-file", aliases=["a"], help="Add one or more files to the search database"
+    )
+    add_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
+    add_parser.add_argument(
+        "file_paths", nargs="+", help="Path(s) to the PDF file(s) to add"
+    )
+
     # Query command
-    query_parser = subparsers.add_parser('query', aliases=['q'], help='Query the search database')
-    query_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
-    query_parser.add_argument('query_text', help='Text to search for')
-
+    query_parser = subparsers.add_parser(
+        "query", aliases=["q"], help="Query the search database"
+    )
+    query_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
+    query_parser.add_argument("query_text", help="Text to search for")
+
     # Host command (web server)
-    host_parser = subparsers.add_parser('host', aliases=['h'], help='Start a web server for semantic search')
-    host_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
-    host_parser.add_argument('--host', default='127.0.0.1', help='Host address to bind to (default: 127.0.0.1)')
-    host_parser.add_argument('--port', type=int, default=5000, help='Port to listen on (default: 5000)')
-
+    host_parser = subparsers.add_parser(
+        "host", aliases=["h"], help="Start a web server for semantic search"
+    )
+    host_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
+    host_parser.add_argument(
+        "--host",
+        default="127.0.0.1",
+        help="Host address to bind to (default: 127.0.0.1)",
+    )
+    host_parser.add_argument(
+        "--port", type=int, default=5000, help="Port to listen on (default: 5000)"
+    )
+
     # Test command
-    subparsers.add_parser('test', aliases=['t'], help='Test database save/load functionality')
-
+    subparsers.add_parser(
+        "test", aliases=["t"], help="Test database save/load functionality"
+    )
+
     # Parse arguments
     args = parser.parse_args()
-
+
     # Handle commands
-    if args.command in ['create', 'c']:
+    if args.command in ["create", "c"]:
         create_database(args.db_path)
-    elif args.command in ['add-file', 'a']:
+    elif args.command in ["add-file", "a"]:
         add_file(args.db, args.file_paths)
-    elif args.command in ['query', 'q']:
+    elif args.command in ["query", "q"]:
         query(args.db, args.query_text)
-    elif args.command in ['host', 'h']:
+    elif args.command in ["host", "h"]:
         start_web_server(args.db, args.host, args.port)
-    elif args.command in ['test', 't']:
+    elif args.command in ["test", "t"]:
         test_database()
     else:
         parser.print_help()
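
Editor's note on the similarity metric: query() keeps the NOTE that Euclidean distance is used for ranking and that cosine or dot-product similarity are alternatives. The sketch below shows what a cosine-distance variant of _find_nearest() could look like. It is not part of this patch; the name _find_nearest_cosine is illustrative only, and it assumes the same list-of-vectors layout and non-zero embedding vectors.

    # Sketch only, not part of the patch above.
    import numpy as np

    def _find_nearest_cosine(
        vectors_db: list[np.ndarray], query_vector: np.ndarray, count: int = 10
    ) -> list[tuple[float, int]]:
        # Cosine distance = 1 - cosine similarity (0.0 means same direction).
        # Assumes all vectors have non-zero norm.
        query_norm = np.linalg.norm(query_vector)
        distances = [
            1.0 - float(np.dot(v, query_vector) / (np.linalg.norm(v) * query_norm))
            for v in vectors_db
        ]
        # Same (distance, index) contract as _find_nearest().
        order = np.argsort(distances)
        return [(float(distances[i]), int(i)) for i in order[:count]]

For unit-normalized embeddings the cosine ranking coincides with both the dot-product and the Euclidean ranking (since ||a - b||^2 = 2 - 2 a.b for unit vectors), so the choice of metric mainly matters when vector norms vary.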