DB Record: use document index instead of Path
db.py (162 changed lines)
@@ -17,11 +17,17 @@ type VectorBytes = bytes
 @dataclass(slots=True)
 class Record:
-    document: Path
+    document_index: int
     page: int
     text: str
     chunk: int = 0  # Chunk number within the page (0-indexed)
 
 
+@dataclass(slots=True)
+class QueryResult:
+    record: Record
+    distance: float
+    document: Path
+
+
 @dataclass(slots=True)
 class Database:
     """
@@ -34,6 +40,7 @@ class Database:
     """
     vectors: list[Vector]
     records: dict[VectorBytes, Record]
+    documents: list[Path]
 
 
 #
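With this change a Record no longer carries its source Path; it stores an integer index into the database-wide documents list, so each path is stored once per file instead of once per chunk. A minimal sketch of the indirection, using a hypothetical path and field values:

    from pathlib import Path

    db = Database(vectors=[], records={}, documents=[Path("papers/example.pdf")])
    rec = Record(document_index=0, page=3, text="...", chunk=1)

    # Resolve the index back to the real path:
    source = db.documents[rec.document_index]  # Path("papers/example.pdf")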
@@ -74,32 +81,8 @@ def _embed(text: str) -> Vector:
     return np.array(ollama.embeddings(model=MODEL, prompt=text)["embedding"])
 
 
-def _process_chunk(chunk_data: tuple) -> tuple[Vector, Record]:
-    """
-    Process a single chunk: generate embedding and create record.
-    Used for multithreading.
-
-    Args:
-        chunk_data: Tuple of (file_path, page_num, chunk_idx, chunk_text)
-
-    Returns:
-        Tuple of (vector, record)
-    """
-    file_path, page_num, chunk_idx, chunk_text = chunk_data
-
-    # Generate embedding vector for this chunk
-    vector = _embed(chunk_text)
-
-    # Create record for this chunk
-    record = Record(
-        document=file_path,
-        page=page_num + 1,  # 1-indexed for user-friendliness
-        text=chunk_text,
-        chunk=chunk_idx + 1  # 1-indexed for user-friendliness
-    )
-
-    return vector, record
+def _vectorize_record(record: Record) -> Vector:
+    return _embed(record.text)
 
 
 #
 # High-level (exported) functions
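Both the old and new helpers feed the records dict, which is keyed by the raw bytes of each embedding vector (type VectorBytes = bytes). A short round-trip sketch, assuming float64 NumPy vectors as elsewhere in the file:

    import numpy as np

    vec = np.array([1.0, 2.0, 3.0])                   # stands in for _embed(...) output
    key = vec.tobytes()                               # bytes key used in db.records
    restored = np.frombuffer(key, dtype=np.float64)   # recover the vector
    assert np.array_equal(vec, restored)              # dtype must match, hence 'vector_dtype' in save()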
@@ -109,9 +92,9 @@ def create_dummy() -> Database:
     db_length: Final[int] = 10
     vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)]
     records = {
-        vector.tobytes(): Record(Path("dummy"), 1, "Lorem my ipsum", 1) for vector in vectors
+        vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
     }
-    return Database(vectors, records)
+    return Database(vectors, records, [Path("dummy")])
 
 
 def create_empty() -> Database:
@@ -121,7 +104,7 @@ def create_empty() -> Database:
     Returns:
        Empty Database object
     """
-    return Database(vectors=[], records={})
+    return Database(vectors=[], records={}, documents=[])
 
 
 def load(database_file: Path) -> Database:
@@ -152,7 +135,9 @@ def load(database_file: Path) -> Database:
     # Records already use bytes as keys, so we can use them directly
     records = serializable_db['records']
 
-    return Database(vectors, records)
+    documents = serializable_db['documents']
+
+    return Database(vectors, records, documents)
 
 
 def save(db: Database, database_file: Path) -> None:
     """
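Note that load() now reads serializable_db['documents'] unconditionally, so database files pickled before this commit would raise KeyError. If older files need to keep working, a tolerant read is one option (a sketch of a possible follow-up, not part of this commit):

    documents = serializable_db.get('documents', [])  # empty list for pre-migration files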
@@ -171,7 +156,8 @@ def save(db: Database, database_file: Path) -> None:
         'vectors': [vector.tobytes() for vector in db.vectors],
         'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
         'vector_shape': db.vectors[0].shape if db.vectors else (),
-        'records': db.records  # Already uses bytes as keys
+        'records': db.records,  # Already uses bytes as keys
+        'documents': db.documents,
     }
 
     # Save to file
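A minimal save/load round trip under the new schema, using the module's own helpers (file names are hypothetical):

    db = create_empty()
    db.documents.append(Path("manual.pdf"))
    save(db, Path("index.pkl"))

    db2 = load(Path("index.pkl"))
    assert db2.documents == [Path("manual.pdf")]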
@@ -179,7 +165,7 @@ def save(db: Database, database_file: Path) -> None:
         pickle.dump(serializable_db, f)
 
 
-def query(db: Database | Path, text: str, record_count: int = 10) -> list[tuple[float, Record]]:
+def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]:
     """
     Query the database and return the N nearest records.
 
@@ -204,7 +190,7 @@ def query(db: Database | Path, text: str, record_count: int = 10) -> list[Query
     nearest_results = _find_nearest(db.vectors, query_vector, record_count)
 
     # Convert results to (distance, Record) tuples
-    results = []
+    results: list[QueryResult] = []
     for distance, vector_idx in nearest_results:
         # Get the vector at this index
         vector = db.vectors[vector_idx]
@@ -213,7 +199,7 @@ def query(db: Database | Path, text: str, record_count: int = 10) -> list[Query
         # Look up the corresponding record
         if vector_bytes in db.records:
             record = db.records[vector_bytes]
-            results.append((distance, record))
+            results.append(QueryResult(record, distance, db.documents[record.document_index]))
 
     return results
 
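Callers now receive QueryResult objects instead of (distance, Record) tuples, with the document path already resolved via db.documents[record.document_index]. A usage sketch mirroring the main.py changes further down (query text is hypothetical):

    for res in query(db, "example query", record_count=3):
        print(f"{res.distance:.4f}  {res.document.name}  "
              f"page {res.record.page}, chunk {res.record.chunk}")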
@@ -237,7 +223,6 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None
         db = load(db)
         save_to_file = True
 
-    # Validate that the file exists and is a PDF
     if not file.exists():
         raise FileNotFoundError(f"File not found: {file}")
 
@@ -246,83 +231,70 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None
 
     print(f"Processing PDF: {file}")
 
-    # Open PDF with PyMuPDF
+    document_index = len(db.documents)
     try:
         doc = pymupdf.open(file)
         print(f"PDF opened successfully: {len(doc)} pages")
 
-        # Collect all chunks to process
-        chunk_tasks = []
+        records: list[Record] = []
+        chunk_size = 1024
 
-        # Process each page to collect chunks
         for page_num in range(len(doc)):
             page = doc[page_num]
 
-            # Extract text from page
             text = page.get_text().strip()
 
-            # Skip empty pages
             if not text:
                 print(f"  Page {page_num + 1}: Skipped (empty)")
                 continue
 
-            #print(f"  Page {page_num + 1}: {len(text)} characters")
-
-            # Split page text into chunks of 1024 characters
-            chunk_size = 1024
-            chunks = []
-
             # Simple chunking - split text into chunks of specified size
-            for i in range(0, len(text), chunk_size):
+            for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
                 chunk = text[i:i + chunk_size]
-                if chunk.strip():  # Only add non-empty chunks
-                    chunks.append(chunk.strip())
-
-            #print(f"  Split into {len(chunks)} chunks")
-
-            # Add chunk tasks for parallel processing
-            for chunk_idx, chunk_text in enumerate(chunks):
-                chunk_tasks.append((file, page_num, chunk_idx, chunk_text))
+                if chunk_stripped := chunk.strip():  # Only add non-empty chunks
+                    # page_num + 1 for user-friendliness
+                    records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx))
 
         doc.close()
 
-        # Process chunks in parallel
-        print(f"Processing {len(chunk_tasks)} chunks with {max_workers} workers...")
-
-        # TODO measure with GIL disabled to check if multithreading actually helps
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # Submit all chunk processing tasks
-            future_to_chunk = {
-                executor.submit(_process_chunk, chunk_data): chunk_data
-                for chunk_data in chunk_tasks
-            }
-
-            # Collect results as they complete
-            completed_chunks = 0
-            for future in as_completed(future_to_chunk):
-                try:
-                    vector, record = future.result()
-
-                    # Convert vector to bytes for use as dictionary key
-                    vector_bytes = vector.tobytes()
-
-                    # Add to database
-                    db.vectors.append(vector)
-                    db.records[vector_bytes] = record
-
-                    completed_chunks += 1
-                    # if completed_chunks % 10 == 0 or completed_chunks == len(chunk_tasks):
-                    #     print(f"  Completed {completed_chunks}/{len(chunk_tasks)} chunks")
-
-                except Exception as e:
-                    chunk_data = future_to_chunk[future]
-                    print(f"  Error processing chunk {chunk_data}: {e}")
-
-        print(f"Successfully processed {file}: {len(chunk_tasks)} chunks")
-
     except Exception as e:
         raise RuntimeError(f"Error processing PDF {file}: {e}")
 
+    # Process chunks in parallel
+    print(f"Processing {len(records)} chunks with {max_workers} workers...")
+
+    db.documents.append(file)
+    for record in records:
+        vector = _vectorize_record(record)
+        db.records[vector.tobytes()] = record
+        db.vectors.append(vector)
+
+    # TODO measure with GIL disabled to check if multithreading actually helps
+    # with ThreadPoolExecutor(max_workers=max_workers) as executor:
+    #     # Submit all chunk processing tasks
+    #     vector_futures = {
+    #         executor.submit(_vectorize_record, r): r for r in records
+    #     }
+
+    #     # Collect results as they complete
+    #     completed_chunks = 0
+    #     for future in as_completed(vector_futures):
+    #         try:
+    #             vector = future.result()
+    #             # Convert vector to bytes for use as dictionary key
+    #             vector_bytes = vector.tobytes()
+    #             # Add to database
+    #             db.vectors.append(vector)
+    #             db.records[vector_bytes] = record
+
+    #             completed_chunks += 1
+    #             # if completed_chunks % 10 == 0 or completed_chunks == len(chunk_tasks):
+    #             #     print(f"  Completed {completed_chunks}/{len(chunk_tasks)} chunks")
+
+    #         except Exception as e:
+    #             chunk_data = future_to_chunk[future]
+    #             print(f"  Error processing chunk {chunk_data}: {e}")
+
+    print(f"Successfully processed {file}: {len(records)} chunks")
+
 
     # Save database if we loaded it from file
     if save_to_file and database_file_path:
         save(db, database_file_path)
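Ingestion is now a plain serial loop over records; the ThreadPoolExecutor path survives only as a comment. If that pool is revived later, note that the commented code still references the removed future_to_chunk mapping and a stale record variable from the enclosing loop. A corrected sketch (an assumption about future work, not part of this commit):

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        vector_futures = {executor.submit(_vectorize_record, r): r for r in records}
        for future in as_completed(vector_futures):
            record = vector_futures[future]   # recover the record for this future
            vector = future.result()
            db.records[vector.tobytes()] = record
            db.vectors.append(vector)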
main.py (24 changed lines)
@@ -164,12 +164,12 @@ def query(db_path: str, query_text: str):
     print(f"\nFound {len(results)} results:")
     print("=" * 60)
 
-    for i, (distance, record) in enumerate(results, 1):
-        print(f"\n{i}. Distance: {distance:.4f}")
-        print(f"   Document: {record.document.name}")
-        print(f"   Page: {record.page}, Chunk: {record.chunk}")
+    for i, res in enumerate(results, 1):
+        print(f"\n{i}. Distance: {res.distance:.4f}")
+        print(f"   Document: {res.document.name}")
+        print(f"   Page: {res.record.page}, Chunk: {res.record.chunk}")
         # Replace all whitespace characters with regular spaces for cleaner display
-        clean_text = ' '.join(record.text[:200].split())
+        clean_text = ' '.join(res.record.text[:200].split())
         print(f"   Text preview: {clean_text}...")
         if i < len(results):
             print("-" * 40)
@@ -233,14 +233,14 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
 
         # Format results for JSON response
         formatted_results = []
-        for distance, record in results:
+        for res in results:
             formatted_results.append({
-                'distance': float(distance),
-                'document': record.document.name,
-                'document_path': str(record.document),  # Full path for the link
-                'page': record.page,
-                'chunk': record.chunk,
-                'text': ' '.join(record.text[:300].split())  # Clean and truncate text
+                'distance': float(res.distance),
+                'document': res.document.name,
+                'document_path': str(res.document),  # Full path for the link
+                'page': res.record.page,
+                'chunk': res.record.chunk,
+                'text': ' '.join(res.record.text[:300].split())  # Clean and truncate text
             })
 
         return jsonify({'results': formatted_results})