Format document

This commit is contained in:
Jan Mrna
2025-11-06 10:46:54 +01:00
parent e734a13a59
commit 7010edae44
2 changed files with 220 additions and 156 deletions

48
db.py
View File

@@ -15,6 +15,7 @@ import ollama # TODO split to another file
type Vector = np.NDArray # np.NDArray[np.float32] ? type Vector = np.NDArray # np.NDArray[np.float32] ?
type VectorBytes = bytes type VectorBytes = bytes
@dataclass(slots=True) @dataclass(slots=True)
class Record: class Record:
document_index: int document_index: int
@@ -22,12 +23,14 @@ class Record:
text: str text: str
chunk: int = 0 # Chunk number within the page (0-indexed) chunk: int = 0 # Chunk number within the page (0-indexed)
@dataclass(slots=True) @dataclass(slots=True)
class QueryResult: class QueryResult:
record: Record record: Record
distance: float distance: float
document_name: str document_name: str
@dataclass(slots=True) @dataclass(slots=True)
class Database: class Database:
""" """
@@ -36,6 +39,7 @@ class Database:
TODO For faster nearest neighbour lookup we should use something else, TODO For faster nearest neighbour lookup we should use something else,
e.g. kd-trees e.g. kd-trees
""" """
vectors: list[Vector] vectors: list[Vector]
records: dict[VectorBytes, Record] records: dict[VectorBytes, Record]
documents: list[Path] documents: list[Path]
@@ -46,7 +50,9 @@ class Database:
# #
def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]: def _find_nearest(
vectors_db: list[Vector], query_vector: Vector, count: int = 10
) -> list[tuple[float, int]]:
""" """
Find the N nearest vectors to the query embedding. Find the N nearest vectors to the query embedding.
@@ -71,6 +77,7 @@ def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 1
return results return results
def _embed(text: str) -> Vector: def _embed(text: str) -> Vector:
""" """
Generate embedding vector for given text. Generate embedding vector for given text.
@@ -82,10 +89,12 @@ def _embed(text: str) -> Vector:
def _vectorize_record(record: Record) -> tuple[Record, Vector]: def _vectorize_record(record: Record) -> tuple[Record, Vector]:
return record, _embed(record.text) return record, _embed(record.text)
# #
# High-level (exported) functions # High-level (exported) functions
# #
def create_dummy() -> Database: def create_dummy() -> Database:
db_length: Final[int] = 10 db_length: Final[int] = 10
vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)] vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)]
@@ -118,25 +127,26 @@ def load(database_file: Path) -> Database:
if not database_file.exists(): if not database_file.exists():
raise FileNotFoundError(f"Database file not found: {database_file}") raise FileNotFoundError(f"Database file not found: {database_file}")
with open(database_file, 'rb') as f: with open(database_file, "rb") as f:
serializable_db = pickle.load(f) serializable_db = pickle.load(f)
# Reconstruct vectors from bytes # Reconstruct vectors from bytes
vectors = [] vectors = []
vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64')) vector_dtype = np.dtype(serializable_db.get("vector_dtype", "float64"))
vector_shape = serializable_db.get('vector_shape', ()) vector_shape = serializable_db.get("vector_shape", ())
for vector_bytes in serializable_db['vectors']: for vector_bytes in serializable_db["vectors"]:
vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape) vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
vectors.append(vector) vectors.append(vector)
# Records already use bytes as keys, so we can use them directly # Records already use bytes as keys, so we can use them directly
records = serializable_db['records'] records = serializable_db["records"]
documents = serializable_db['documents'] documents = serializable_db["documents"]
return Database(vectors, records, documents) return Database(vectors, records, documents)
def save(db: Database, database_file: Path) -> None: def save(db: Database, database_file: Path) -> None:
""" """
Saves the database to a file using pickle serialization. Saves the database to a file using pickle serialization.
@@ -151,15 +161,15 @@ def save(db: Database, database_file: Path) -> None:
# Create a serializable version of the database # Create a serializable version of the database
# Records already use bytes as keys, so we can use them directly # Records already use bytes as keys, so we can use them directly
serializable_db = { serializable_db = {
'vectors': [vector.tobytes() for vector in db.vectors], "vectors": [vector.tobytes() for vector in db.vectors],
'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64', "vector_dtype": str(db.vectors[0].dtype) if db.vectors else "float64",
'vector_shape': db.vectors[0].shape if db.vectors else (), "vector_shape": db.vectors[0].shape if db.vectors else (),
'records': db.records, # Already uses bytes as keys "records": db.records, # Already uses bytes as keys
'documents': db.documents, "documents": db.documents,
} }
# Save to file # Save to file
with open(database_file, 'wb') as f: with open(database_file, "wb") as f:
pickle.dump(serializable_db, f) pickle.dump(serializable_db, f)
@@ -197,10 +207,13 @@ def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryR
# Look up the corresponding record # Look up the corresponding record
if vector_bytes in db.records: if vector_bytes in db.records:
record = db.records[vector_bytes] record = db.records[vector_bytes]
results.append(QueryResult(record, distance, db.documents[record.document_index].name)) results.append(
QueryResult(record, distance, db.documents[record.document_index].name)
)
return results return results
def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None: def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
""" """
Adds a new document to the database. If path is given, do load, add, save. Adds a new document to the database. If path is given, do load, add, save.
@@ -224,7 +237,7 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
if not file.exists(): if not file.exists():
raise FileNotFoundError(f"File not found: {file}") raise FileNotFoundError(f"File not found: {file}")
if file.suffix.lower() != '.pdf': if file.suffix.lower() != ".pdf":
raise ValueError(f"File must be a PDF: {file}") raise ValueError(f"File must be a PDF: {file}")
print(f"Processing PDF: {file}") print(f"Processing PDF: {file}")
@@ -249,7 +262,9 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
chunk = text[i : i + chunk_size] chunk = text[i : i + chunk_size]
if chunk_stripped := chunk.strip(): # Only add non-empty chunks if chunk_stripped := chunk.strip(): # Only add non-empty chunks
# page_num + 1 for use friendliness # page_num + 1 for use friendliness
records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx)) records.append(
Record(document_index, page_num + 1, chunk_stripped, chunk_idx)
)
doc.close() doc.close()
except Exception as e: except Exception as e:
raise RuntimeError(f"Error processing PDF {file}: {e}") raise RuntimeError(f"Error processing PDF {file}: {e}")
@@ -274,6 +289,7 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
save(db, database_file_path) save(db, database_file_path)
print(f"Database saved to {database_file_path}") print(f"Database saved to {database_file_path}")
def get_document_path(db: Database | Path, document_index: int) -> Path: def get_document_path(db: Database | Path, document_index: int) -> Path:
""" """
Get the file path of the document at the given index in the database. Get the file path of the document at the given index in the database.

152
main.py
View File

@@ -9,6 +9,7 @@ import db
DEFAULT_DB_PATH: Final[Path] = Path("db.pkl") DEFAULT_DB_PATH: Final[Path] = Path("db.pkl")
def test_database(): def test_database():
"""Test database save/load functionality by creating, saving, loading and comparing.""" """Test database save/load functionality by creating, saving, loading and comparing."""
print("=== Database Test ===") print("=== Database Test ===")
@@ -16,15 +17,23 @@ def test_database():
# Create dummy database # Create dummy database
print("1. Creating dummy database...") print("1. Creating dummy database...")
original_db = db.create_dummy() original_db = db.create_dummy()
print(f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records") print(
f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records"
)
# Print some details about the original database # Print some details about the original database
print(" First vector shape:", original_db.vectors[0].shape if original_db.vectors else "No vectors") print(
print(" Sample vector:", original_db.vectors[0][:4] if original_db.vectors else "No vectors") " First vector shape:",
original_db.vectors[0].shape if original_db.vectors else "No vectors",
)
print(
" Sample vector:",
original_db.vectors[0][:4] if original_db.vectors else "No vectors",
)
print(" Sample record keys (first 3):", list(original_db.records.keys())[:3]) print(" Sample record keys (first 3):", list(original_db.records.keys())[:3])
# Create temporary file for testing # Create temporary file for testing
with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file: with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
test_file = Path(tmp_file.name) test_file = Path(tmp_file.name)
try: try:
@@ -36,7 +45,9 @@ def test_database():
# Load database # Load database
print(f"\n3. Loading database from {test_file}...") print(f"\n3. Loading database from {test_file}...")
loaded_db = db.load(test_file) loaded_db = db.load(test_file)
print(f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records") print(
f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records"
)
# Compare databases # Compare databases
print("\n4. Comparing original vs loaded...") print("\n4. Comparing original vs loaded...")
@@ -52,7 +63,9 @@ def test_database():
# Check vector equality # Check vector equality
vectors_equal = True vectors_equal = True
if vectors_match and original_db.vectors: if vectors_match and original_db.vectors:
for i, (orig, loaded) in enumerate(zip(original_db.vectors, loaded_db.vectors)): for i, (orig, loaded) in enumerate(
zip(original_db.vectors, loaded_db.vectors)
):
if not np.array_equal(orig, loaded): if not np.array_equal(orig, loaded):
vectors_equal = False vectors_equal = False
print(f" Vector {i} mismatch!") print(f" Vector {i} mismatch!")
@@ -77,14 +90,24 @@ def test_database():
print("\n5. Testing embedding functionality (Ollama API server)...") print("\n5. Testing embedding functionality (Ollama API server)...")
try: try:
test_embedding = db._embed("This is a test text for embedding.") test_embedding = db._embed("This is a test text for embedding.")
print(f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}") print(
f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}"
)
ollama_running = True ollama_running = True
except Exception as e: except Exception as e:
print(f" Embedding test FAILED: {e}\n Did you start ollama docker image?") print(
f" Embedding test FAILED: {e}\n Did you start ollama docker image?"
)
ollama_running = False ollama_running = False
# Summary # Summary
all_good = vectors_match and records_match and vectors_equal and records_equal and ollama_running all_good = (
vectors_match
and records_match
and vectors_equal
and records_equal
and ollama_running
)
print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}") print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}")
finally: finally:
@@ -101,7 +124,7 @@ def create_database(db_path: str):
# Check if file already exists # Check if file already exists
if db_file.exists(): if db_file.exists():
response = input(f"Database {db_file} already exists. Overwrite? (y/N): ") response = input(f"Database {db_file} already exists. Overwrite? (y/N): ")
if response.lower() != 'y': if response.lower() != "y":
print("Operation cancelled.") print("Operation cancelled.")
return return
@@ -169,7 +192,7 @@ def query(db_path: str, query_text: str):
print(f" Document: {res.document_name}") print(f" Document: {res.document_name}")
print(f" Page: {res.record.page}, Chunk: {res.record.chunk}") print(f" Page: {res.record.page}, Chunk: {res.record.chunk}")
# Replace all whitespace characters with regular spaces for cleaner display # Replace all whitespace characters with regular spaces for cleaner display
clean_text = ' '.join(res.record.text[:200].split()) clean_text = " ".join(res.record.text[:200].split())
print(f" Text preview: {clean_text}...") print(f" Text preview: {clean_text}...")
if i < len(results): if i < len(results):
print("-" * 40) print("-" * 40)
@@ -178,7 +201,6 @@ def query(db_path: str, query_text: str):
print(f"Error querying database: {e}") print(f"Error querying database: {e}")
def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000): def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
"""Start a web server for the semantic search tool.""" """Start a web server for the semantic search tool."""
try: try:
@@ -197,31 +219,31 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
print(" Create a database first using: python main.py create") print(" Create a database first using: python main.py create")
sys.exit(1) sys.exit(1)
@app.route('/') @app.route("/")
def index(): def index():
return render_template("index.html", results=None) return render_template("index.html", results=None)
@app.route('/file/<int:document_index>') @app.route("/file/<int:document_index>")
def serve_file(document_index): def serve_file(document_index):
"""Serve PDF files directly.""" """Serve PDF files directly."""
try: try:
file_path = db.get_document_path(db_file, document_index) file_path = db.get_document_path(db_file, document_index)
if not file_path.exists(): if not file_path.exists():
return jsonify({'error': 'File not found'}), 404 return jsonify({"error": "File not found"}), 404
return send_file(file_path, as_attachment=False) return send_file(file_path, as_attachment=False)
except Exception as e: except Exception as e:
return jsonify({'error': str(e)}), 500 return jsonify({"error": str(e)}), 500
@app.route('/api/search', methods=['POST']) @app.route("/api/search", methods=["POST"])
def search(): def search():
try: try:
data = request.get_json() data = request.get_json()
if not data or 'query' not in data: if not data or "query" not in data:
return jsonify({'error': 'Missing query parameter'}), 400 return jsonify({"error": "Missing query parameter"}), 400
query_text = data['query'].strip() query_text = data["query"].strip()
if not query_text: if not query_text:
return jsonify({'error': 'Query cannot be empty'}), 400 return jsonify({"error": "Query cannot be empty"}), 400
# Perform the search # Perform the search
results = db.query(db_file, query_text) results = db.query(db_file, query_text)
@@ -229,18 +251,22 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
# Format results for JSON response # Format results for JSON response
formatted_results = [] formatted_results = []
for res in results: for res in results:
formatted_results.append({ formatted_results.append(
'distance': float(res.distance), {
'document_name': res.document_name, "distance": float(res.distance),
'document_index': res.record.document_index, "document_name": res.document_name,
'page': res.record.page, "document_index": res.record.document_index,
'chunk': res.record.chunk, "page": res.record.page,
'text': ' '.join(res.record.text[:300].split()) # Clean and truncate text "chunk": res.record.chunk,
}) "text": " ".join(
return jsonify({'results': formatted_results}) res.record.text[:300].split()
), # Clean and truncate text
}
)
return jsonify({"results": formatted_results})
except Exception as e: except Exception as e:
return jsonify({'error': str(e)}), 500 return jsonify({"error": str(e)}), 500
print("🚀 Starting web server...") print("🚀 Starting web server...")
print(f" Database: {db_file}") print(f" Database: {db_file}")
@@ -258,49 +284,71 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Semantic Search Tool", description="Semantic Search Tool",
formatter_class=argparse.RawDescriptionHelpFormatter formatter_class=argparse.RawDescriptionHelpFormatter,
) )
# Create subparsers for different commands # Create subparsers for different commands
subparsers = parser.add_subparsers(dest='command', help='Available commands') subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Create command # Create command
create_parser = subparsers.add_parser('create', aliases=['c'], help='Create a new empty database') create_parser = subparsers.add_parser(
create_parser.add_argument('db_path', nargs='?', default=str(DEFAULT_DB_PATH), "create", aliases=["c"], help="Create a new empty database"
help=f'Path to database file (default: {DEFAULT_DB_PATH})') )
create_parser.add_argument(
"db_path",
nargs="?",
default=str(DEFAULT_DB_PATH),
help=f"Path to database file (default: {DEFAULT_DB_PATH})",
)
# Add file command # Add file command
add_parser = subparsers.add_parser('add-file', aliases=['a'], help='Add one or more files to the search database') add_parser = subparsers.add_parser(
add_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)') "add-file", aliases=["a"], help="Add one or more files to the search database"
add_parser.add_argument('file_paths', nargs='+', help='Path(s) to the PDF file(s) to add') )
add_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
add_parser.add_argument(
"file_paths", nargs="+", help="Path(s) to the PDF file(s) to add"
)
# Query command # Query command
query_parser = subparsers.add_parser('query', aliases=['q'], help='Query the search database') query_parser = subparsers.add_parser(
query_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)') "query", aliases=["q"], help="Query the search database"
query_parser.add_argument('query_text', help='Text to search for') )
query_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
query_parser.add_argument("query_text", help="Text to search for")
# Host command (web server) # Host command (web server)
host_parser = subparsers.add_parser('host', aliases=['h'], help='Start a web server for semantic search') host_parser = subparsers.add_parser(
host_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)') "host", aliases=["h"], help="Start a web server for semantic search"
host_parser.add_argument('--host', default='127.0.0.1', help='Host address to bind to (default: 127.0.0.1)') )
host_parser.add_argument('--port', type=int, default=5000, help='Port to listen on (default: 5000)') host_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
host_parser.add_argument(
"--host",
default="127.0.0.1",
help="Host address to bind to (default: 127.0.0.1)",
)
host_parser.add_argument(
"--port", type=int, default=5000, help="Port to listen on (default: 5000)"
)
# Test command # Test command
subparsers.add_parser('test', aliases=['t'], help='Test database save/load functionality') subparsers.add_parser(
"test", aliases=["t"], help="Test database save/load functionality"
)
# Parse arguments # Parse arguments
args = parser.parse_args() args = parser.parse_args()
# Handle commands # Handle commands
if args.command in ['create', 'c']: if args.command in ["create", "c"]:
create_database(args.db_path) create_database(args.db_path)
elif args.command in ['add-file', 'a']: elif args.command in ["add-file", "a"]:
add_file(args.db, args.file_paths) add_file(args.db, args.file_paths)
elif args.command in ['query', 'q']: elif args.command in ["query", "q"]:
query(args.db, args.query_text) query(args.db, args.query_text)
elif args.command in ['host', 'h']: elif args.command in ["host", "h"]:
start_web_server(args.db, args.host, args.port) start_web_server(args.db, args.host, args.port)
elif args.command in ['test', 't']: elif args.command in ["test", "t"]:
test_database() test_database()
else: else:
parser.print_help() parser.print_help()