Multithread adding files (again)

This commit is contained in:
Jan Mrna
2025-11-05 14:41:26 +01:00
parent 9bc39ccea8
commit 491d79c617

40
db.py
View File

@@ -81,8 +81,8 @@ def _embed(text: str) -> Vector:
return np.array(ollama.embeddings(model=MODEL, prompt=text)["embedding"])
def _vectorize_record(record: Record) -> Vector:
return _embed(record.text)
def _vectorize_record(record: Record) -> tuple[Record, Vector]:
return record, _embed(record.text)
#
# High-level (exported) functions
@@ -260,41 +260,17 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
print(f"Processing {len(records)} chunks with {max_workers} workers...")
db.documents.append(file)
for record in records:
vector = _vectorize_record(record)
db.records[vector.tobytes()] = record
db.vectors.append(vector)
# TODO measure with GIL disabled to check if multithreading actually helps
# with ThreadPoolExecutor(max_workers=max_workers) as executor:
# # Submit all chunk processing tasks
# vector_futures = {
# executor.submit(_vectorize_record, r): r for r in records
# }
# # Collect results as they complete
# completed_chunks = 0
# for future in as_completed(vector_futures):
# try:
# vector = future.result()
# # Convert vector to bytes for use as dictionary key
# vector_bytes = vector.tobytes()
# # Add to database
# db.vectors.append(vector)
# db.records[vector_bytes] = record
# completed_chunks += 1
# # if completed_chunks % 10 == 0 or completed_chunks == len(chunk_tasks):
# # print(f" Completed {completed_chunks}/{len(chunk_tasks)} chunks")
# except Exception as e:
# chunk_data = future_to_chunk[future]
# print(f" Error processing chunk {chunk_data}: {e}")
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [ pool.submit(_vectorize_record, r) for r in records ]
for f in as_completed(futures):
record, vector = f.result()
db.records[vector.tobytes()] = record
db.vectors.append(vector)
print(f"Successfully processed {file}: {len(records)} chunks")
# Save database if we loaded it from file
if save_to_file and database_file_path:
save(db, database_file_path)