DB Record: use document index instead of Path

Jan Mrna
2025-11-05 14:31:36 +01:00
parent c0c80142be
commit 9bc39ccea8
2 changed files with 80 additions and 108 deletions

db.py (164 lines changed)

@@ -17,11 +17,17 @@ type VectorBytes = bytes
 @dataclass(slots=True)
 class Record:
-    document: Path
+    document_index: int
     page: int
     text: str
     chunk: int = 0  # Chunk number within the page (0-indexed)
 
 
+@dataclass(slots=True)
+class QueryResult:
+    record: Record
+    distance: float
+    document: Path
+
+
 @dataclass(slots=True)
 class Database:
     """
@@ -34,6 +40,7 @@ class Database:
     """
     vectors: list[Vector]
    records: dict[VectorBytes, Record]
+    documents: list[Path]
 
 
 #
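To make the reworked data model concrete, here is a minimal sketch of how the three dataclasses relate after this change. The module import, file names, vector values, and the distance are illustrative assumptions, not part of the commit:

from pathlib import Path

import numpy as np

from db import Database, QueryResult, Record  # assumes db.py from this repo is importable

# A Record no longer stores its own Path; it points into Database.documents by index.
documents = [Path("manual.pdf"), Path("notes.pdf")]
record = Record(document_index=1, page=3, text="example chunk text", chunk=0)

vector = np.array([0.1, 0.2, 0.3, 0.4])
database = Database(vectors=[vector], records={vector.tobytes(): record}, documents=documents)

# QueryResult carries the already-resolved Path next to the record and its distance.
result = QueryResult(record=record, distance=0.42, document=database.documents[record.document_index])
print(result.document.name)  # notes.pdf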
@@ -74,32 +81,8 @@ def _embed(text: str) -> Vector:
     return np.array(ollama.embeddings(model=MODEL, prompt=text)["embedding"])
 
 
-def _process_chunk(chunk_data: tuple) -> tuple[Vector, Record]:
-    """
-    Process a single chunk: generate embedding and create record.
-    Used for multithreading.
-
-    Args:
-        chunk_data: Tuple of (file_path, page_num, chunk_idx, chunk_text)
-
-    Returns:
-        Tuple of (vector, record)
-    """
-    file_path, page_num, chunk_idx, chunk_text = chunk_data
-
-    # Generate embedding vector for this chunk
-    vector = _embed(chunk_text)
-
-    # Create record for this chunk
-    record = Record(
-        document=file_path,
-        page=page_num + 1,  # 1-indexed for user-friendliness
-        text=chunk_text,
-        chunk=chunk_idx + 1  # 1-indexed for user-friendliness
-    )
-
-    return vector, record
+def _vectorize_record(record: Record) -> Vector:
+    return _embed(record.text)
 
 
 #
 # High-level (exported) functions
@@ -109,9 +92,9 @@ def create_dummy() -> Database:
     db_length: Final[int] = 10
     vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)]
     records = {
-        vector.tobytes(): Record(Path("dummy"), 1, "Lorem my ipsum", 1) for vector in vectors
+        vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
     }
-    return Database(vectors, records)
+    return Database(vectors, records, [Path("dummy")])
 
 
 def create_empty() -> Database:
@@ -121,7 +104,7 @@ def create_empty() -> Database:
     Returns:
         Empty Database object
     """
-    return Database(vectors=[], records={})
+    return Database(vectors=[], records={}, documents=[])
 
 
 def load(database_file: Path) -> Database:
@@ -151,8 +134,10 @@ def load(database_file: Path) -> Database:
     # Records already use bytes as keys, so we can use them directly
     records = serializable_db['records']
 
-    return Database(vectors, records)
+    documents = serializable_db['documents']
+
+    return Database(vectors, records, documents)
 
 
 def save(db: Database, database_file: Path) -> None:
     """
@@ -171,7 +156,8 @@ def save(db: Database, database_file: Path) -> None:
         'vectors': [vector.tobytes() for vector in db.vectors],
         'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
         'vector_shape': db.vectors[0].shape if db.vectors else (),
-        'records': db.records  # Already uses bytes as keys
+        'records': db.records,  # Already uses bytes as keys
+        'documents': db.documents,
     }
 
     # Save to file
@@ -179,7 +165,7 @@ def save(db: Database, database_file: Path) -> None:
         pickle.dump(serializable_db, f)
 
 
-def query(db: Database | Path, text: str, record_count: int = 10) -> list[tuple[float, Record]]:
+def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]:
     """
     Query the database and return the N nearest records.
@@ -204,7 +190,7 @@ def query(db: Database | Path, text: str, record_count: int = 10) -> list[tuple[
     nearest_results = _find_nearest(db.vectors, query_vector, record_count)
 
     # Convert results to (distance, Record) tuples
-    results = []
+    results: list[QueryResult] = []
     for distance, vector_idx in nearest_results:
         # Get the vector at this index
         vector = db.vectors[vector_idx]
@@ -213,7 +199,7 @@ def query(db: Database | Path, text: str, record_count: int = 10) -> list[tuple[
         # Look up the corresponding record
         if vector_bytes in db.records:
             record = db.records[vector_bytes]
-            results.append((distance, record))
+            results.append(QueryResult(record, distance, db.documents[record.document_index]))
 
     return results
@@ -237,7 +223,6 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
         db = load(db)
         save_to_file = True
 
-    # Validate that the file exists and is a PDF
     if not file.exists():
         raise FileNotFoundError(f"File not found: {file}")
 
@@ -246,82 +231,69 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
     print(f"Processing PDF: {file}")
 
-    # Open PDF with PyMuPDF
+    document_index = len(db.documents)
     try:
         doc = pymupdf.open(file)
         print(f"PDF opened successfully: {len(doc)} pages")
 
-        # Collect all chunks to process
-        chunk_tasks = []
-
-        # Process each page to collect chunks
+        records: list[Record] = []
+        chunk_size = 1024
         for page_num in range(len(doc)):
             page = doc[page_num]
 
-            # Extract text from page
             text = page.get_text().strip()
 
-            # Skip empty pages
             if not text:
                 print(f"  Page {page_num + 1}: Skipped (empty)")
                 continue
 
-            #print(f"  Page {page_num + 1}: {len(text)} characters")
-
-            # Split page text into chunks of 1024 characters
-            chunk_size = 1024
-            chunks = []
-
             # Simple chunking - split text into chunks of specified size
-            for i in range(0, len(text), chunk_size):
+            for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
                 chunk = text[i:i + chunk_size]
-                if chunk.strip():  # Only add non-empty chunks
-                    chunks.append(chunk.strip())
-
-            #print(f"  Split into {len(chunks)} chunks")
-
-            # Add chunk tasks for parallel processing
-            for chunk_idx, chunk_text in enumerate(chunks):
-                chunk_tasks.append((file, page_num, chunk_idx, chunk_text))
+                if chunk_stripped := chunk.strip():  # Only add non-empty chunks
+                    # page_num + 1 for user-friendliness
+                    records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx))
 
         doc.close()
 
-        # Process chunks in parallel
-        print(f"Processing {len(chunk_tasks)} chunks with {max_workers} workers...")
-        # TODO measure with GIL disabled to check if multithreading actually helps
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # Submit all chunk processing tasks
-            future_to_chunk = {
-                executor.submit(_process_chunk, chunk_data): chunk_data
-                for chunk_data in chunk_tasks
-            }
-
-            # Collect results as they complete
-            completed_chunks = 0
-            for future in as_completed(future_to_chunk):
-                try:
-                    vector, record = future.result()
-
-                    # Convert vector to bytes for use as dictionary key
-                    vector_bytes = vector.tobytes()
-
-                    # Add to database
-                    db.vectors.append(vector)
-                    db.records[vector_bytes] = record
-
-                    completed_chunks += 1
-                    # if completed_chunks % 10 == 0 or completed_chunks == len(chunk_tasks):
-                    #     print(f"  Completed {completed_chunks}/{len(chunk_tasks)} chunks")
-
-                except Exception as e:
-                    chunk_data = future_to_chunk[future]
-                    print(f"  Error processing chunk {chunk_data}: {e}")
-
-        print(f"Successfully processed {file}: {len(chunk_tasks)} chunks")
-
     except Exception as e:
         raise RuntimeError(f"Error processing PDF {file}: {e}")
 
+    # Process chunks in parallel
+    print(f"Processing {len(records)} chunks with {max_workers} workers...")
+    db.documents.append(file)
+    for record in records:
+        vector = _vectorize_record(record)
+        db.records[vector.tobytes()] = record
+        db.vectors.append(vector)
+    # TODO measure with GIL disabled to check if multithreading actually helps
+    # with ThreadPoolExecutor(max_workers=max_workers) as executor:
+    #     # Submit all chunk processing tasks
+    #     vector_futures = {
+    #         executor.submit(_vectorize_record, r): r for r in records
+    #     }
+    #     # Collect results as they complete
+    #     completed_chunks = 0
+    #     for future in as_completed(vector_futures):
+    #         try:
+    #             vector = future.result()
+    #             # Convert vector to bytes for use as dictionary key
+    #             vector_bytes = vector.tobytes()
+    #             # Add to database
+    #             db.vectors.append(vector)
    #             db.records[vector_bytes] = record
    #             completed_chunks += 1
    #             # if completed_chunks % 10 == 0 or completed_chunks == len(chunk_tasks):
    #             #     print(f"  Completed {completed_chunks}/{len(chunk_tasks)} chunks")
    #         except Exception as e:
    #             chunk_data = future_to_chunk[future]
    #             print(f"  Error processing chunk {chunk_data}: {e}")
+    print(f"Successfully processed {file}: {len(records)} chunks")
 
     # Save database if we loaded it from file
     if save_to_file and database_file_path:
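End to end, the reworked add/query flow would be used roughly as follows. This is a sketch only: it assumes a running Ollama instance reachable by db.py, and the PDF and database file names are made up:

from pathlib import Path

import db

database = db.create_empty()
db.add_document(database, Path("example.pdf"))  # chunks the PDF, embeds each chunk, appends to documents

for res in db.query(database, "what does the introduction cover?", record_count=5):
    # res.document is the resolved Path; the raw index lives on res.record.document_index
    print(f"{res.distance:.4f}  {res.document.name}  page {res.record.page}, chunk {res.record.chunk}")

db.save(database, Path("library.db"))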

main.py (24 lines changed)

@@ -164,12 +164,12 @@ def query(db_path: str, query_text: str):
     print(f"\nFound {len(results)} results:")
     print("=" * 60)
 
-    for i, (distance, record) in enumerate(results, 1):
-        print(f"\n{i}. Distance: {distance:.4f}")
-        print(f"   Document: {record.document.name}")
-        print(f"   Page: {record.page}, Chunk: {record.chunk}")
+    for i, res in enumerate(results, 1):
+        print(f"\n{i}. Distance: {res.distance:.4f}")
+        print(f"   Document: {res.document.name}")
+        print(f"   Page: {res.record.page}, Chunk: {res.record.chunk}")
         # Replace all whitespace characters with regular spaces for cleaner display
-        clean_text = ' '.join(record.text[:200].split())
+        clean_text = ' '.join(res.record.text[:200].split())
         print(f"   Text preview: {clean_text}...")
         if i < len(results):
             print("-" * 40)
@@ -233,14 +233,14 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
         # Format results for JSON response
         formatted_results = []
-        for distance, record in results:
+        for res in results:
             formatted_results.append({
-                'distance': float(distance),
-                'document': record.document.name,
-                'document_path': str(record.document),  # Full path for the link
-                'page': record.page,
-                'chunk': record.chunk,
-                'text': ' '.join(record.text[:300].split())  # Clean and truncate text
+                'distance': float(res.distance),
+                'document': res.document.name,
+                'document_path': str(res.document),  # Full path for the link
+                'page': res.record.page,
+                'chunk': res.record.chunk,
+                'text': ' '.join(res.record.text[:300].split())  # Clean and truncate text
             })
 
         return jsonify({'results': formatted_results})
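For reference, a single entry of formatted_results produced by the loop above would look roughly like the following; the values are invented for illustration, only the keys come from the code:

sample_entry = {
    'distance': 0.4271,
    'document': 'example.pdf',
    'document_path': '/home/user/docs/example.pdf',
    'page': 3,
    'chunk': 0,
    'text': 'cleaned and truncated chunk text ...'
}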