Format document

commit e352780a3d
parent 2fb7a7d224
Author: Jan Mrna
Date: 2025-11-06 10:46:54 +01:00

2 changed files with 220 additions and 156 deletions

db.py (58 changes)

@@ -6,15 +6,16 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 import numpy as np
 import pymupdf
-import ollama # TODO split to another file
+import ollama  # TODO split to another file

 #
 # Types
 #
-type Vector = np.NDArray # np.NDArray[np.float32] ?
+type Vector = np.NDArray  # np.NDArray[np.float32] ?
 type VectorBytes = bytes

+
 @dataclass(slots=True)
 class Record:
     document_index: int
@@ -22,12 +23,14 @@ class Record:
     text: str
     chunk: int = 0  # Chunk number within the page (0-indexed)

+
 @dataclass(slots=True)
 class QueryResult:
     record: Record
     distance: float
     document_name: str

+
 @dataclass(slots=True)
 class Database:
     """
@@ -36,6 +39,7 @@ class Database:
     TODO For faster nearest neighbour lookup we should use something else,
     e.g. kd-trees
     """
+
     vectors: list[Vector]
     records: dict[VectorBytes, Record]
     documents: list[Path]
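
The TODO above suggests kd-trees for faster lookup. A minimal sketch of that swap, assuming SciPy were added as a dependency (it is not part of this commit, and the function name is illustrative):

    import numpy as np
    from scipy.spatial import cKDTree

    def find_nearest_kdtree(
        vectors_db: list[np.ndarray], query_vector: np.ndarray, count: int = 10
    ) -> list[tuple[float, int]]:
        # Build the tree once; for a mostly-static database the O(n log n)
        # build cost is amortized over many O(log n) queries.
        tree = cKDTree(np.stack(vectors_db))
        distances, indices = tree.query(query_vector, k=count)
        return list(zip(distances.tolist(), indices.tolist()))

Note that kd-trees degrade toward brute force as dimensionality grows, so for embedding-sized vectors an approximate index (e.g. FAISS or hnswlib) is the more common choice.
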
@@ -46,7 +50,9 @@
 #
-def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
+def _find_nearest(
+    vectors_db: list[Vector], query_vector: Vector, count: int = 10
+) -> list[tuple[float, int]]:
     """
     Find the N nearest vectors to the query embedding.
@@ -71,6 +77,7 @@ def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 1
     return results

+
 def _embed(text: str) -> Vector:
     """
     Generate embedding vector for given text.
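
Only the signature of _find_nearest is reformatted here; its body is not shown in this hunk. A brute-force implementation consistent with the signature and docstring, hypothetical and assuming plain Euclidean distance, would be:

    import numpy as np

    def _find_nearest_bruteforce(
        vectors_db: list[np.ndarray], query_vector: np.ndarray, count: int = 10
    ) -> list[tuple[float, int]]:
        # One distance per stored vector: O(n * d) per query.
        distances = np.linalg.norm(np.stack(vectors_db) - query_vector, axis=1)
        # Indices of the `count` smallest distances, in ascending order.
        nearest = np.argsort(distances)[:count]
        return [(float(distances[i]), int(i)) for i in nearest]
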
@@ -82,13 +89,15 @@ def _embed(text: str) -> Vector:
 def _vectorize_record(record: Record) -> tuple[Record, Vector]:
     return record, _embed(record.text)

+
 #
 # High-level (exported) functions
 #
 def create_dummy() -> Database:
     db_length: Final[int] = 10
-    vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)]
+    vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)]
     records = {
         vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
     }
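
_embed (context above) delegates to the Ollama client imported at the top of db.py. A sketch of what such a wrapper typically looks like, assuming the `ollama` Python package and a pulled embedding model; the model name is an assumption, not taken from this commit:

    import numpy as np
    import ollama

    def embed_via_ollama(text: str) -> np.ndarray:
        # Older client versions return a plain dict; newer ones return a
        # response object that also supports ["embedding"] access.
        response = ollama.embeddings(model="nomic-embed-text", prompt=text)
        return np.array(response["embedding"], dtype=np.float32)
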
@@ -118,25 +127,26 @@ def load(database_file: Path) -> Database:
     if not database_file.exists():
         raise FileNotFoundError(f"Database file not found: {database_file}")

-    with open(database_file, 'rb') as f:
+    with open(database_file, "rb") as f:
         serializable_db = pickle.load(f)

     # Reconstruct vectors from bytes
     vectors = []
-    vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64'))
-    vector_shape = serializable_db.get('vector_shape', ())
+    vector_dtype = np.dtype(serializable_db.get("vector_dtype", "float64"))
+    vector_shape = serializable_db.get("vector_shape", ())

-    for vector_bytes in serializable_db['vectors']:
+    for vector_bytes in serializable_db["vectors"]:
         vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
         vectors.append(vector)

     # Records already use bytes as keys, so we can use them directly
-    records = serializable_db['records']
+    records = serializable_db["records"]

-    documents = serializable_db['documents']
+    documents = serializable_db["documents"]

     return Database(vectors, records, documents)


 def save(db: Database, database_file: Path) -> None:
     """
     Saves the database to a file using pickle serialization.
@@ -151,15 +161,15 @@ def save(db: Database, database_file: Path) -> None:
     # Create a serializable version of the database
     # Records already use bytes as keys, so we can use them directly
     serializable_db = {
-        'vectors': [vector.tobytes() for vector in db.vectors],
-        'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
-        'vector_shape': db.vectors[0].shape if db.vectors else (),
-        'records': db.records,  # Already uses bytes as keys
-        'documents': db.documents,
+        "vectors": [vector.tobytes() for vector in db.vectors],
+        "vector_dtype": str(db.vectors[0].dtype) if db.vectors else "float64",
+        "vector_shape": db.vectors[0].shape if db.vectors else (),
+        "records": db.records,  # Already uses bytes as keys
+        "documents": db.documents,
     }

     # Save to file
-    with open(database_file, 'wb') as f:
+    with open(database_file, "wb") as f:
         pickle.dump(serializable_db, f)
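
The reason save() stores dtype and shape next to the raw buffers: ndarray.tobytes() keeps only the data, so load() needs both to reconstruct the arrays. The round trip in isolation:

    import numpy as np

    original = np.arange(12, dtype=np.float64).reshape(3, 4)
    raw = original.tobytes()  # raw buffer only; dtype and shape are lost
    restored = np.frombuffer(raw, dtype=original.dtype).reshape(original.shape)
    assert np.array_equal(original, restored)
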
@@ -197,10 +207,13 @@ def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryR
         # Look up the corresponding record
         if vector_bytes in db.records:
             record = db.records[vector_bytes]
-            results.append(QueryResult(record, distance, db.documents[record.document_index].name))
+            results.append(
+                QueryResult(record, distance, db.documents[record.document_index].name)
+            )

     return results
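
Hypothetical usage of the exported query() above, assuming a previously built db.pkl; the path and query string are illustrative:

    from pathlib import Path
    import db

    results = db.query(Path("db.pkl"), "nearest neighbour search", record_count=5)
    for res in results:
        print(f"{res.document_name} p.{res.record.page}: distance {res.distance:.3f}")
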

 def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
     """
     Adds a new document to the database. If path is given, do load, add, save.
@@ -224,7 +237,7 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
     if not file.exists():
         raise FileNotFoundError(f"File not found: {file}")

-    if file.suffix.lower() != '.pdf':
+    if file.suffix.lower() != ".pdf":
         raise ValueError(f"File must be a PDF: {file}")

     print(f"Processing PDF: {file}")
@@ -246,10 +259,12 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
             # Simple chunking - split text into chunks of specified size
             for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
-                chunk = text[i:i + chunk_size]
+                chunk = text[i : i + chunk_size]
                 if chunk_stripped := chunk.strip():  # Only add non-empty chunks
                     # page_num + 1 for user friendliness
-                    records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx))
+                    records.append(
+                        Record(document_index, page_num + 1, chunk_stripped, chunk_idx)
+                    )
         doc.close()
     except Exception as e:
         raise RuntimeError(f"Error processing PDF {file}: {e}")
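
The chunking loop above produces fixed-size, non-overlapping slices; the final slice is simply shorter. In isolation:

    text = "abcdefghij"
    chunk_size = 4
    chunks = [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]
    assert chunks == ["abcd", "efgh", "ij"]
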
@@ -261,7 +276,7 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
     # TODO measure with GIL disabled to check if multithreading actually helps
     with ThreadPoolExecutor(max_workers=max_workers) as pool:
-        futures = [ pool.submit(_vectorize_record, r) for r in records ]
+        futures = [pool.submit(_vectorize_record, r) for r in records]
         for f in as_completed(futures):
             record, vector = f.result()
             db.records[vector.tobytes()] = record
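
The submit/as_completed pattern above, reduced to a self-contained example (results arrive in completion order, not submission order):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def work(n: int) -> int:
        return n * n

    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(work, n) for n in range(8)]
        for fut in as_completed(futures):
            print(fut.result())

On the TODO: since _vectorize_record spends most of its time waiting on the Ollama HTTP API, the work is I/O-bound and threads can already help even with the GIL enabled.
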
@@ -274,6 +289,7 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
         save(db, database_file_path)
         print(f"Database saved to {database_file_path}")

+
 def get_document_path(db: Database | Path, document_index: int) -> Path:
     """
     Get the file path of the document at the given index in the database.

main.py (152 changes)

@@ -9,6 +9,7 @@ import db
 DEFAULT_DB_PATH: Final[Path] = Path("db.pkl")

+
 def test_database():
     """Test database save/load functionality by creating, saving, loading and comparing."""
     print("=== Database Test ===")
@@ -16,15 +17,23 @@ def test_database():
     # Create dummy database
     print("1. Creating dummy database...")
     original_db = db.create_dummy()
-    print(f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records")
+    print(
+        f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records"
+    )

     # Print some details about the original database
-    print(" First vector shape:", original_db.vectors[0].shape if original_db.vectors else "No vectors")
-    print(" Sample vector:", original_db.vectors[0][:4] if original_db.vectors else "No vectors")
+    print(
+        " First vector shape:",
+        original_db.vectors[0].shape if original_db.vectors else "No vectors",
+    )
+    print(
+        " Sample vector:",
+        original_db.vectors[0][:4] if original_db.vectors else "No vectors",
+    )
     print(" Sample record keys (first 3):", list(original_db.records.keys())[:3])

     # Create temporary file for testing
-    with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file:
+    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
         test_file = Path(tmp_file.name)

     try:
@@ -36,7 +45,9 @@ def test_database():
         # Load database
         print(f"\n3. Loading database from {test_file}...")
         loaded_db = db.load(test_file)
-        print(f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records")
+        print(
+            f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records"
+        )

         # Compare databases
         print("\n4. Comparing original vs loaded...")
@@ -52,7 +63,9 @@ def test_database():
         # Check vector equality
         vectors_equal = True
         if vectors_match and original_db.vectors:
-            for i, (orig, loaded) in enumerate(zip(original_db.vectors, loaded_db.vectors)):
+            for i, (orig, loaded) in enumerate(
+                zip(original_db.vectors, loaded_db.vectors)
+            ):
                 if not np.array_equal(orig, loaded):
                     vectors_equal = False
                     print(f" Vector {i} mismatch!")
@@ -77,14 +90,24 @@ def test_database():
         print("\n5. Testing embedding functionality (Ollama API server)...")
         try:
             test_embedding = db._embed("This is a test text for embedding.")
-            print(f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}")
+            print(
+                f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}"
+            )
             ollama_running = True
         except Exception as e:
-            print(f" Embedding test FAILED: {e}\n Did you start ollama docker image?")
+            print(
+                f" Embedding test FAILED: {e}\n Did you start ollama docker image?"
+            )
             ollama_running = False

         # Summary
-        all_good = vectors_match and records_match and vectors_equal and records_equal and ollama_running
+        all_good = (
+            vectors_match
+            and records_match
+            and vectors_equal
+            and records_equal
+            and ollama_running
+        )
         print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}")

     finally:
@@ -101,7 +124,7 @@ def create_database(db_path: str):
     # Check if file already exists
     if db_file.exists():
         response = input(f"Database {db_file} already exists. Overwrite? (y/N): ")
-        if response.lower() != 'y':
+        if response.lower() != "y":
             print("Operation cancelled.")
             return
@@ -169,7 +192,7 @@ def query(db_path: str, query_text: str):
             print(f" Document: {res.document_name}")
             print(f" Page: {res.record.page}, Chunk: {res.record.chunk}")
             # Replace all whitespace characters with regular spaces for cleaner display
-            clean_text = ' '.join(res.record.text[:200].split())
+            clean_text = " ".join(res.record.text[:200].split())
             print(f" Text preview: {clean_text}...")
             if i < len(results):
                 print("-" * 40)
@@ -178,7 +201,6 @@ def query(db_path: str, query_text: str):
         print(f"Error querying database: {e}")

-
 def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
     """Start a web server for the semantic search tool."""
     try:
@@ -197,31 +219,31 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
         print(" Create a database first using: python main.py create")
         sys.exit(1)

-    @app.route('/')
+    @app.route("/")
     def index():
         return render_template("index.html", results=None)

-    @app.route('/file/<int:document_index>')
+    @app.route("/file/<int:document_index>")
     def serve_file(document_index):
         """Serve PDF files directly."""
         try:
             file_path = db.get_document_path(db_file, document_index)
             if not file_path.exists():
-                return jsonify({'error': 'File not found'}), 404
+                return jsonify({"error": "File not found"}), 404
             return send_file(file_path, as_attachment=False)
         except Exception as e:
-            return jsonify({'error': str(e)}), 500
+            return jsonify({"error": str(e)}), 500

-    @app.route('/api/search', methods=['POST'])
+    @app.route("/api/search", methods=["POST"])
     def search():
         try:
             data = request.get_json()
-            if not data or 'query' not in data:
-                return jsonify({'error': 'Missing query parameter'}), 400
+            if not data or "query" not in data:
+                return jsonify({"error": "Missing query parameter"}), 400

-            query_text = data['query'].strip()
+            query_text = data["query"].strip()
             if not query_text:
-                return jsonify({'error': 'Query cannot be empty'}), 400
+                return jsonify({"error": "Query cannot be empty"}), 400

             # Perform the search
             results = db.query(db_file, query_text)
@@ -229,18 +251,22 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
             # Format results for JSON response
             formatted_results = []
             for res in results:
-                formatted_results.append({
-                    'distance': float(res.distance),
-                    'document_name': res.document_name,
-                    'document_index': res.record.document_index,
-                    'page': res.record.page,
-                    'chunk': res.record.chunk,
-                    'text': ' '.join(res.record.text[:300].split())  # Clean and truncate text
-                })
-
-            return jsonify({'results': formatted_results})
+                formatted_results.append(
+                    {
+                        "distance": float(res.distance),
+                        "document_name": res.document_name,
+                        "document_index": res.record.document_index,
+                        "page": res.record.page,
+                        "chunk": res.record.chunk,
+                        "text": " ".join(
+                            res.record.text[:300].split()
+                        ),  # Clean and truncate text
+                    }
+                )
+
+            return jsonify({"results": formatted_results})
         except Exception as e:
-            return jsonify({'error': str(e)}), 500
+            return jsonify({"error": str(e)}), 500

     print("🚀 Starting web server...")
     print(f" Database: {db_file}")
@@ -258,49 +284,71 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
 def main():
     parser = argparse.ArgumentParser(
         description="Semantic Search Tool",
-        formatter_class=argparse.RawDescriptionHelpFormatter
+        formatter_class=argparse.RawDescriptionHelpFormatter,
     )

     # Create subparsers for different commands
-    subparsers = parser.add_subparsers(dest='command', help='Available commands')
+    subparsers = parser.add_subparsers(dest="command", help="Available commands")

     # Create command
-    create_parser = subparsers.add_parser('create', aliases=['c'], help='Create a new empty database')
-    create_parser.add_argument('db_path', nargs='?', default=str(DEFAULT_DB_PATH),
-                               help=f'Path to database file (default: {DEFAULT_DB_PATH})')
+    create_parser = subparsers.add_parser(
+        "create", aliases=["c"], help="Create a new empty database"
+    )
+    create_parser.add_argument(
+        "db_path",
+        nargs="?",
+        default=str(DEFAULT_DB_PATH),
+        help=f"Path to database file (default: {DEFAULT_DB_PATH})",
+    )

     # Add file command
-    add_parser = subparsers.add_parser('add-file', aliases=['a'], help='Add one or more files to the search database')
-    add_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
-    add_parser.add_argument('file_paths', nargs='+', help='Path(s) to the PDF file(s) to add')
+    add_parser = subparsers.add_parser(
+        "add-file", aliases=["a"], help="Add one or more files to the search database"
+    )
+    add_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
+    add_parser.add_argument(
+        "file_paths", nargs="+", help="Path(s) to the PDF file(s) to add"
+    )

     # Query command
-    query_parser = subparsers.add_parser('query', aliases=['q'], help='Query the search database')
-    query_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
-    query_parser.add_argument('query_text', help='Text to search for')
+    query_parser = subparsers.add_parser(
+        "query", aliases=["q"], help="Query the search database"
+    )
+    query_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
+    query_parser.add_argument("query_text", help="Text to search for")

     # Host command (web server)
-    host_parser = subparsers.add_parser('host', aliases=['h'], help='Start a web server for semantic search')
-    host_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
-    host_parser.add_argument('--host', default='127.0.0.1', help='Host address to bind to (default: 127.0.0.1)')
-    host_parser.add_argument('--port', type=int, default=5000, help='Port to listen on (default: 5000)')
+    host_parser = subparsers.add_parser(
+        "host", aliases=["h"], help="Start a web server for semantic search"
+    )
+    host_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
+    host_parser.add_argument(
+        "--host",
+        default="127.0.0.1",
+        help="Host address to bind to (default: 127.0.0.1)",
+    )
+    host_parser.add_argument(
+        "--port", type=int, default=5000, help="Port to listen on (default: 5000)"
+    )

     # Test command
-    subparsers.add_parser('test', aliases=['t'], help='Test database save/load functionality')
+    subparsers.add_parser(
+        "test", aliases=["t"], help="Test database save/load functionality"
+    )

     # Parse arguments
     args = parser.parse_args()

     # Handle commands
-    if args.command in ['create', 'c']:
+    if args.command in ["create", "c"]:
         create_database(args.db_path)
-    elif args.command in ['add-file', 'a']:
+    elif args.command in ["add-file", "a"]:
         add_file(args.db, args.file_paths)
-    elif args.command in ['query', 'q']:
+    elif args.command in ["query", "q"]:
         query(args.db, args.query_text)
-    elif args.command in ['host', 'h']:
+    elif args.command in ["host", "h"]:
         start_web_server(args.db, args.host, args.port)
-    elif args.command in ['test', 't']:
+    elif args.command in ["test", "t"]:
         test_database()
     else:
         parser.print_help()
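
Example invocations of the CLI defined above (file names are illustrative):

    python main.py create db.pkl
    python main.py add-file db.pkl manual.pdf datasheet.pdf
    python main.py query db.pkl "overcurrent protection"
    python main.py host db.pkl --host 0.0.0.0 --port 8080
    python main.py test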