From 9bc39ccea8c09047e9a54e92732c99bf48048357 Mon Sep 17 00:00:00 2001
From: Jan Mrna
Date: Wed, 5 Nov 2025 14:31:36 +0100
Subject: [PATCH] DB Record: use document index instead of Path

---
 db.py   | 164 +++++++++++++++++++++++---------------------------------
 main.py |  24 ++++-----
 2 files changed, 80 insertions(+), 108 deletions(-)

diff --git a/db.py b/db.py
index d89f10d..99bf2be 100644
--- a/db.py
+++ b/db.py
@@ -17,11 +17,17 @@ type VectorBytes = bytes
 
 @dataclass(slots=True)
 class Record:
-    document: Path
+    document_index: int
     page: int
     text: str
    chunk: int = 0  # Chunk number within the page (0-indexed)
 
+@dataclass(slots=True)
+class QueryResult:
+    record: Record
+    distance: float
+    document: Path
+
 @dataclass(slots=True)
 class Database:
     """
@@ -34,6 +40,7 @@ class Database:
     """
     vectors: list[Vector]
     records: dict[VectorBytes, Record]
+    documents: list[Path]
 
 #
 #
@@ -74,32 +81,8 @@ def _embed(text: str) -> Vector:
     return np.array(ollama.embeddings(model=MODEL, prompt=text)["embedding"])
 
 
-def _process_chunk(chunk_data: tuple) -> tuple[Vector, Record]:
-    """
-    Process a single chunk: generate embedding and create record.
-    Used for multithreading.
-
-    Args:
-        chunk_data: Tuple of (file_path, page_num, chunk_idx, chunk_text)
-
-    Returns:
-        Tuple of (vector, record)
-    """
-    file_path, page_num, chunk_idx, chunk_text = chunk_data
-
-    # Generate embedding vector for this chunk
-    vector = _embed(chunk_text)
-
-    # Create record for this chunk
-    record = Record(
-        document=file_path,
-        page=page_num + 1,  # 1-indexed for user-friendliness
-        text=chunk_text,
-        chunk=chunk_idx + 1  # 1-indexed for user-friendliness
-    )
-
-    return vector, record
-
+def _vectorize_record(record: Record) -> Vector:
+    return _embed(record.text)
 #
 # High-level (exported) functions
 #
@@ -109,9 +92,9 @@ def create_dummy() -> Database:
     db_length: Final[int] = 10
     vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)]
     records = {
-        vector.tobytes(): Record(Path("dummy"), 1, "Lorem my ipsum", 1) for vector in vectors
+        vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
     }
-    return Database(vectors, records)
+    return Database(vectors, records, [Path("dummy")])
 
 
 def create_empty() -> Database:
@@ -121,7 +104,7 @@ def create_empty() -> Database:
     Returns:
         Empty Database object
     """
-    return Database(vectors=[], records={})
+    return Database(vectors=[], records={}, documents=[])
 
 
 def load(database_file: Path) -> Database:
@@ -151,8 +134,10 @@ def load(database_file: Path) -> Database:
 
     # Records already use bytes as keys, so we can use them directly
    records = serializable_db['records']
+
+    documents = serializable_db['documents']
 
-    return Database(vectors, records)
+    return Database(vectors, records, documents)
 
 def save(db: Database, database_file: Path) -> None:
     """
@@ -171,7 +156,8 @@ def save(db: Database, database_file: Path) -> None:
         'vectors': [vector.tobytes() for vector in db.vectors],
         'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
         'vector_shape': db.vectors[0].shape if db.vectors else (),
-        'records': db.records  # Already uses bytes as keys
+        'records': db.records,  # Already uses bytes as keys
+        'documents': db.documents,
     }
 
     # Save to file
@@ -179,7 +165,7 @@ def save(db: Database, database_file: Path) -> None:
         pickle.dump(serializable_db, f)
 
 
-def query(db: Database | Path, text: str, record_count: int = 10) -> list[tuple[float, Record]]:
+def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]:
     """
     Query the database and return the N nearest records.
 
@@ -204,7 +190,7 @@ def query(db: Database | Path, text: str, record_count: int = 10) -> list[tuple[
     nearest_results = _find_nearest(db.vectors, query_vector, record_count)
 
     # Convert results to (distance, Record) tuples
-    results = []
+    results: list[QueryResult] = []
     for distance, vector_idx in nearest_results:
         # Get the vector at this index
         vector = db.vectors[vector_idx]
@@ -213,7 +199,7 @@ def query(db: Database | Path, text: str, record_count: int = 10) -> list[tuple[
         # Look up the corresponding record
         if vector_bytes in db.records:
             record = db.records[vector_bytes]
-            results.append((distance, record))
+            results.append(QueryResult(record, distance, db.documents[record.document_index]))
 
     return results
 
@@ -237,7 +223,6 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
         db = load(db)
         save_to_file = True
 
-    # Validate that the file exists and is a PDF
     if not file.exists():
         raise FileNotFoundError(f"File not found: {file}")
 
@@ -246,82 +231,69 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
 
     print(f"Processing PDF: {file}")
 
-    # Open PDF with PyMuPDF
+    document_index = len(db.documents)
     try:
         doc = pymupdf.open(file)
         print(f"PDF opened successfully: {len(doc)} pages")
 
-        # Collect all chunks to process
-        chunk_tasks = []
-
-        # Process each page to collect chunks
+        records: list[Record] = []
+        chunk_size = 1024
+
         for page_num in range(len(doc)):
             page = doc[page_num]
-
-            # Extract text from page
             text = page.get_text().strip()
-
-            # Skip empty pages
             if not text:
                 print(f"  Page {page_num + 1}: Skipped (empty)")
                 continue
 
-            #print(f"  Page {page_num + 1}: {len(text)} characters")
-
-            # Split page text into chunks of 1024 characters
-            chunk_size = 1024
-            chunks = []
-            # Simple chunking - split text into chunks of specified size
-            for i in range(0, len(text), chunk_size):
+            for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
                 chunk = text[i:i + chunk_size]
-                if chunk.strip():  # Only add non-empty chunks
-                    chunks.append(chunk.strip())
-
-            #print(f"  Split into {len(chunks)} chunks")
-
-            # Add chunk tasks for parallel processing
-            for chunk_idx, chunk_text in enumerate(chunks):
-                chunk_tasks.append((file, page_num, chunk_idx, chunk_text))
-
+                if chunk_stripped := chunk.strip():  # Only add non-empty chunks
+                    # page_num + 1 for user friendliness
+                    records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx))
         doc.close()
-
-        # Process chunks in parallel
-        print(f"Processing {len(chunk_tasks)} chunks with {max_workers} workers...")
-
-        # TODO measure with GIL disabled to check if multithreading actually helps
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # Submit all chunk processing tasks
-            future_to_chunk = {
-                executor.submit(_process_chunk, chunk_data): chunk_data
-                for chunk_data in chunk_tasks
-            }
-
-            # Collect results as they complete
-            completed_chunks = 0
-            for future in as_completed(future_to_chunk):
-                try:
-                    vector, record = future.result()
-
-                    # Convert vector to bytes for use as dictionary key
-                    vector_bytes = vector.tobytes()
-
-                    # Add to database
-                    db.vectors.append(vector)
-                    db.records[vector_bytes] = record
-
-                    completed_chunks += 1
-                    # if completed_chunks % 10 == 0 or completed_chunks == len(chunk_tasks):
-                    #     print(f"  Completed {completed_chunks}/{len(chunk_tasks)} chunks")
-
-                except Exception as e:
-                    chunk_data = future_to_chunk[future]
-                    print(f"  Error processing chunk {chunk_data}: {e}")
-
-        print(f"Successfully processed {file}: {len(chunk_tasks)} chunks")
-
     except Exception as e:
         raise RuntimeError(f"Error processing PDF {file}: {e}")
+
+    # Process chunks in parallel
+    print(f"Processing {len(records)} chunks with {max_workers} workers...")
+
+    db.documents.append(file)
+    for record in records:
+        vector = _vectorize_record(record)
+        db.records[vector.tobytes()] = record
+        db.vectors.append(vector)
+
+    # TODO measure with GIL disabled to check if multithreading actually helps
+    # with ThreadPoolExecutor(max_workers=max_workers) as executor:
+    #     # Submit all chunk processing tasks
+    #     vector_futures = {
+    #         executor.submit(_vectorize_record, r): r for r in records
+    #     }
+
+    #     # Collect results as they complete
+    #     completed_chunks = 0
+    #     for future in as_completed(vector_futures):
+    #         try:
+    #             vector = future.result()
+    #             # Convert vector to bytes for use as dictionary key
+    #             vector_bytes = vector.tobytes()
+    #             # Add to database
+    #             db.vectors.append(vector)
+    #             db.records[vector_bytes] = record
+
+    #             completed_chunks += 1
+    #             # if completed_chunks % 10 == 0 or completed_chunks == len(chunk_tasks):
+    #             #     print(f"  Completed {completed_chunks}/{len(chunk_tasks)} chunks")
+
+    #         except Exception as e:
+    #             chunk_data = future_to_chunk[future]
+    #             print(f"  Error processing chunk {chunk_data}: {e}")
+
+    print(f"Successfully processed {file}: {len(records)} chunks")
+
+
 
     # Save database if we loaded it from file
     if save_to_file and database_file_path:
diff --git a/main.py b/main.py
index b764c8a..30106e5 100644
--- a/main.py
+++ b/main.py
@@ -164,12 +164,12 @@ def query(db_path: str, query_text: str):
     print(f"\nFound {len(results)} results:")
     print("=" * 60)
 
-    for i, (distance, record) in enumerate(results, 1):
-        print(f"\n{i}. Distance: {distance:.4f}")
-        print(f"   Document: {record.document.name}")
-        print(f"   Page: {record.page}, Chunk: {record.chunk}")
+    for i, res in enumerate(results, 1):
+        print(f"\n{i}. Distance: {res.distance:.4f}")
+        print(f"   Document: {res.document.name}")
+        print(f"   Page: {res.record.page}, Chunk: {res.record.chunk}")
         # Replace all whitespace characters with regular spaces for cleaner display
-        clean_text = ' '.join(record.text[:200].split())
+        clean_text = ' '.join(res.record.text[:200].split())
         print(f"   Text preview: {clean_text}...")
         if i < len(results):
             print("-" * 40)
@@ -233,14 +233,14 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
 
             # Format results for JSON response
             formatted_results = []
-            for distance, record in results:
+            for res in results:
                 formatted_results.append({
-                    'distance': float(distance),
-                    'document': record.document.name,
-                    'document_path': str(record.document),  # Full path for the link
-                    'page': record.page,
-                    'chunk': record.chunk,
-                    'text': ' '.join(record.text[:300].split())  # Clean and truncate text
+                    'distance': float(res.distance),
+                    'document': res.document.name,
+                    'document_path': str(res.document),  # Full path for the link
+                    'page': res.record.page,
+                    'chunk': res.record.chunk,
+                    'text': ' '.join(res.record.text[:300].split())  # Clean and truncate text
                 })
 
             return jsonify({'results': formatted_results})
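
Usage note (not part of the patch): a minimal sketch, assuming the patched db.py module is importable, of how a caller consumes the QueryResult objects returned by query(). The database file name and query string below are made-up examples.

    from pathlib import Path

    import db

    # Load an existing database file (hypothetical path) and run a query.
    database = db.load(Path("library.db"))
    for res in db.query(database, "how are embeddings stored?", record_count=5):
        # Each QueryResult bundles the matched Record, its distance to the query
        # vector, and the document Path resolved via record.document_index.
        print(f"{res.distance:.4f}  {res.document.name}  "
              f"page {res.record.page}, chunk {res.record.chunk}")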