Format document
db.py
@@ -6,15 +6,16 @@ from concurrent.futures import ThreadPoolExecutor, as_completed

 import numpy as np
 import pymupdf
 import ollama # TODO split to another file

 #
 # Types
 #

 type Vector = np.typing.NDArray # np.typing.NDArray[np.float32] ?
 type VectorBytes = bytes


 @dataclass(slots=True)
 class Record:
     document_index: int
@@ -22,12 +23,14 @@ class Record:
     text: str
     chunk: int = 0 # Chunk number within the page (0-indexed)


 @dataclass(slots=True)
 class QueryResult:
     record: Record
     distance: float
     document_name: str


 @dataclass(slots=True)
 class Database:
     """
@@ -36,41 +39,45 @@ class Database:
     TODO For faster nearest neighbour lookup we should use something else,
     e.g. kd-trees
     """

     vectors: list[Vector]
     records: dict[VectorBytes, Record]
     documents: list[Path]


 #
 # Internal functions
 #


-def _find_nearest(vectors_db: list[Vector], query_vector: Vector, count: int = 10) -> list[tuple[float, int]]:
+def _find_nearest(
+    vectors_db: list[Vector], query_vector: Vector, count: int = 10
+) -> list[tuple[float, int]]:
     """
     Find the N nearest vectors to the query embedding.

     Args:
         vectors_db: List of vectors in the database
         query_vector: Query embedding vector
         count: Number of nearest neighbors to return

     Returns:
         List of (distance, index) tuples, sorted by distance (closest first)
     """
     distances = [np.linalg.norm(x - query_vector) for x in vectors_db]

     # Get indices sorted by distance
     sorted_indices = np.argsort(distances)

     # Return top N results as (distance, index) tuples
     results = []
     for i in range(min(count, len(sorted_indices))):
         idx = sorted_indices[i]
         results.append((float(distances[idx]), int(idx)))

     return results


 def _embed(text: str) -> Vector:
     """
     Generate embedding vector for given text.
@@ -82,13 +89,15 @@ def _embed(text: str) -> Vector:
 def _vectorize_record(record: Record) -> tuple[Record, Vector]:
     return record, _embed(record.text)


 #
 # High-level (exported) functions
 #


 def create_dummy() -> Database:
     db_length: Final[int] = 10
-    vectors = [np.array([i, 2*i, 3*i, 4*i]) for i in range(db_length)]
+    vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)]
     records = {
         vector.tobytes(): Record(0, 1, "Lorem my ipsum", 1) for vector in vectors
     }
@@ -98,7 +107,7 @@ def create_dummy() -> Database:
 def create_empty() -> Database:
     """
     Creates a new empty database with no vectors or records.

     Returns:
         Empty Database object
     """
@@ -108,105 +117,109 @@ def create_empty() -> Database:
 def load(database_file: Path) -> Database:
     """
     Loads a database from the given file.

     Args:
         database_file: Path to the database file

     Returns:
         Database object loaded from file
     """
     if not database_file.exists():
         raise FileNotFoundError(f"Database file not found: {database_file}")

-    with open(database_file, 'rb') as f:
+    with open(database_file, "rb") as f:
         serializable_db = pickle.load(f)

     # Reconstruct vectors from bytes
     vectors = []
-    vector_dtype = np.dtype(serializable_db.get('vector_dtype', 'float64'))
-    vector_shape = serializable_db.get('vector_shape', ())
+    vector_dtype = np.dtype(serializable_db.get("vector_dtype", "float64"))
+    vector_shape = serializable_db.get("vector_shape", ())

-    for vector_bytes in serializable_db['vectors']:
+    for vector_bytes in serializable_db["vectors"]:
         vector = np.frombuffer(vector_bytes, dtype=vector_dtype).reshape(vector_shape)
         vectors.append(vector)

     # Records already use bytes as keys, so we can use them directly
-    records = serializable_db['records']
+    records = serializable_db["records"]

-    documents = serializable_db['documents']
+    documents = serializable_db["documents"]

     return Database(vectors, records, documents)


 def save(db: Database, database_file: Path) -> None:
     """
     Saves the database to a file using pickle serialization.

     Args:
         db: The Database object to save
         database_file: Path where to save the database file
     """
     # Ensure the directory exists
     database_file.parent.mkdir(parents=True, exist_ok=True)

     # Create a serializable version of the database
     # Records already use bytes as keys, so we can use them directly
     serializable_db = {
-        'vectors': [vector.tobytes() for vector in db.vectors],
-        'vector_dtype': str(db.vectors[0].dtype) if db.vectors else 'float64',
-        'vector_shape': db.vectors[0].shape if db.vectors else (),
-        'records': db.records, # Already uses bytes as keys
-        'documents': db.documents,
+        "vectors": [vector.tobytes() for vector in db.vectors],
+        "vector_dtype": str(db.vectors[0].dtype) if db.vectors else "float64",
+        "vector_shape": db.vectors[0].shape if db.vectors else (),
+        "records": db.records, # Already uses bytes as keys
+        "documents": db.documents,
     }

     # Save to file
-    with open(database_file, 'wb') as f:
+    with open(database_file, "wb") as f:
         pickle.dump(serializable_db, f)


 def query(db: Database | Path, text: str, record_count: int = 10) -> list[QueryResult]:
     """
     Query the database and return the N nearest records.

     Args:
         db: Database object or path to database file
         text: Query text to search for
         record_count: Number of nearest neighbors to return (default: 10)

     Returns:
         List of QueryResult objects, sorted by distance (closest first)
     """
     if isinstance(db, Path):
         db = load(db)

     # Generate embedding for query text
     query_vector = _embed(text)

     # Find nearest vectors
     # NOTE We're using euclidean distance as a metric of similarity,
     # there are some alternatives (cos or dot product) which may be used.
     # See https://en.wikipedia.org/wiki/Embedding_(machine_learning)
     nearest_results = _find_nearest(db.vectors, query_vector, record_count)

     # Convert results to QueryResult objects
     results: list[QueryResult] = []
     for distance, vector_idx in nearest_results:
         # Get the vector at this index
         vector = db.vectors[vector_idx]
         vector_bytes = vector.tobytes()

         # Look up the corresponding record
         if vector_bytes in db.records:
             record = db.records[vector_bytes]
-            results.append(QueryResult(record, distance, db.documents[record.document_index].name))
+            results.append(
+                QueryResult(record, distance, db.documents[record.document_index].name)
+            )

     return results


 def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
     """
     Adds a new document to the database. If a path is given, do load, add, save.
     Loads PDF with PyMuPDF, splits by pages, and creates records and vectors.
     Uses multithreading for embedding generation.

     Args:
         db: Database object or path to database file
         file: Path to PDF file to add
@@ -215,25 +228,25 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
     """
     save_to_file = False
     database_file_path = None

     if isinstance(db, Path):
         database_file_path = db
         db = load(db)
         save_to_file = True

     if not file.exists():
         raise FileNotFoundError(f"File not found: {file}")

-    if file.suffix.lower() != '.pdf':
+    if file.suffix.lower() != ".pdf":
         raise ValueError(f"File must be a PDF: {file}")

     print(f"Processing PDF: {file}")

     document_index = len(db.documents)
     try:
         doc = pymupdf.open(file)
         print(f"PDF opened successfully: {len(doc)} pages")

         records: list[Record] = []
         chunk_size = 1024

@@ -243,13 +256,15 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
             if not text:
                 print(f" Page {page_num + 1}: Skipped (empty)")
                 continue

             # Simple chunking - split text into chunks of specified size
             for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
-                chunk = text[i:i + chunk_size]
+                chunk = text[i : i + chunk_size]
                 if chunk_stripped := chunk.strip(): # Only add non-empty chunks
                     # page_num + 1 for user friendliness
-                    records.append(Record(document_index, page_num + 1, chunk_stripped, chunk_idx))
+                    records.append(
+                        Record(document_index, page_num + 1, chunk_stripped, chunk_idx)
+                    )
         doc.close()
     except Exception as e:
         raise RuntimeError(f"Error processing PDF {file}: {e}")
@@ -261,34 +276,35 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:

     # TODO measure with GIL disabled to check if multithreading actually helps
     with ThreadPoolExecutor(max_workers=max_workers) as pool:
-        futures = [ pool.submit(_vectorize_record, r) for r in records ]
+        futures = [pool.submit(_vectorize_record, r) for r in records]
         for f in as_completed(futures):
             record, vector = f.result()
             db.records[vector.tobytes()] = record
             db.vectors.append(vector)

     print(f"Successfully processed {file}: {len(records)} chunks")

     # Save database if we loaded it from file
     if save_to_file and database_file_path:
         save(db, database_file_path)
         print(f"Database saved to {database_file_path}")


 def get_document_path(db: Database | Path, document_index: int) -> Path:
     """
     Get the file path of the document at the given index in the database.

     Args:
         db: Database object or path to database file
         document_index: Index of the document to retrieve

     Returns:
         Path to the document file
     """
     if isinstance(db, Path):
         db = load(db)

     if document_index < 0 or document_index >= len(db.documents):
         raise IndexError(f"Document index out of range: {document_index}")

     return db.documents[document_index]
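
Note: the Database docstring above leaves a TODO for faster nearest-neighbour lookup (e.g. kd-trees), and query() notes cosine or dot-product similarity as alternatives to Euclidean distance. A minimal sketch of what a kd-tree-backed replacement for _find_nearest could look like, assuming SciPy were added as a dependency (the project does not use it today); _find_nearest_kdtree is a hypothetical name:

import numpy as np
from scipy.spatial import cKDTree


def _find_nearest_kdtree(
    vectors_db: list[np.ndarray], query_vector: np.ndarray, count: int = 10
) -> list[tuple[float, int]]:
    # Hypothetical alternative to db._find_nearest; assumes SciPy is available.
    # The tree is built per call for simplicity; a real integration would cache
    # it on Database and rebuild it after add_document().
    tree = cKDTree(np.stack(vectors_db))
    distances, indices = tree.query(query_vector, k=min(count, len(vectors_db)))
    # query() returns scalars when k == 1, so normalise to 1-D arrays.
    distances = np.atleast_1d(distances)
    indices = np.atleast_1d(indices)
    return [(float(d), int(i)) for d, i in zip(distances, indices)]

If cosine similarity is preferred, normalising every vector to unit length before indexing makes Euclidean distance monotonic in cosine distance (for unit vectors, ||a - b||^2 = 2 - 2 cos θ), so the same tree serves both metrics.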
main.py
@@ -9,56 +9,69 @@ import db

 DEFAULT_DB_PATH: Final[Path] = Path("db.pkl")


 def test_database():
     """Test database save/load functionality by creating, saving, loading and comparing."""
     print("=== Database Test ===")

     # Create dummy database
     print("1. Creating dummy database...")
     original_db = db.create_dummy()
-    print(f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records")
+    print(
+        f" Original DB: {len(original_db.vectors)} vectors, {len(original_db.records)} records"
+    )

     # Print some details about the original database
-    print(" First vector shape:", original_db.vectors[0].shape if original_db.vectors else "No vectors")
-    print(" Sample vector:", original_db.vectors[0][:4] if original_db.vectors else "No vectors")
+    print(
+        " First vector shape:",
+        original_db.vectors[0].shape if original_db.vectors else "No vectors",
+    )
+    print(
+        " Sample vector:",
+        original_db.vectors[0][:4] if original_db.vectors else "No vectors",
+    )
     print(" Sample record keys (first 3):", list(original_db.records.keys())[:3])

     # Create temporary file for testing
-    with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file:
+    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
         test_file = Path(tmp_file.name)

     try:
         # Save database
         print(f"\n2. Saving database to {test_file}...")
         db.save(original_db, test_file)
         print(f" File size: {test_file.stat().st_size} bytes")

         # Load database
         print(f"\n3. Loading database from {test_file}...")
         loaded_db = db.load(test_file)
-        print(f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records")
+        print(
+            f" Loaded DB: {len(loaded_db.vectors)} vectors, {len(loaded_db.records)} records"
+        )

         # Compare databases
         print("\n4. Comparing original vs loaded...")

         # Check vector count
         vectors_match = len(original_db.vectors) == len(loaded_db.vectors)
         print(f" Vector count match: {vectors_match}")

         # Check record count
         records_match = len(original_db.records) == len(loaded_db.records)
         print(f" Record count match: {records_match}")

         # Check vector equality
         vectors_equal = True
         if vectors_match and original_db.vectors:
-            for i, (orig, loaded) in enumerate(zip(original_db.vectors, loaded_db.vectors)):
+            for i, (orig, loaded) in enumerate(
+                zip(original_db.vectors, loaded_db.vectors)
+            ):
                 if not np.array_equal(orig, loaded):
                     vectors_equal = False
                     print(f" Vector {i} mismatch!")
                     break
         print(f" All vectors equal: {vectors_equal}")

         # Check record equality
         records_equal = True
         if records_match:
@@ -72,21 +85,31 @@ def test_database():
                     print(" Record content mismatch!")
                     break
         print(f" All records equal: {records_equal}")

         # Test embedding functionality
         print("\n5. Testing embedding functionality (Ollama API server)...")
         try:
             test_embedding = db._embed("This is a test text for embedding.")
-            print(f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}")
+            print(
+                f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}"
+            )
             ollama_running = True
         except Exception as e:
-            print(f" Embedding test FAILED: {e}\n Did you start ollama docker image?")
+            print(
+                f" Embedding test FAILED: {e}\n Did you start the ollama docker image?"
+            )
             ollama_running = False

         # Summary
-        all_good = vectors_match and records_match and vectors_equal and records_equal and ollama_running
+        all_good = (
+            vectors_match
+            and records_match
+            and vectors_equal
+            and records_equal
+            and ollama_running
+        )
         print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}")

     finally:
         # Clean up temporary file
         if test_file.exists():
@@ -97,20 +120,20 @@ def test_database():
 def create_database(db_path: str):
     """Create a new empty database."""
     db_file = Path(db_path)

     # Check if file already exists
     if db_file.exists():
         response = input(f"Database {db_file} already exists. Overwrite? (y/N): ")
-        if response.lower() != 'y':
+        if response.lower() != "y":
             print("Operation cancelled.")
             return

     # Create empty database
     empty_db = db.create_empty()

     # Save to file
     db.save(empty_db, db_file)

     print(f"✅ Created empty database: {db_file}")
     print(f" Vectors: {len(empty_db.vectors)}")
     print(f" Records: {len(empty_db.records)}")
@@ -119,11 +142,11 @@ def create_database(db_path: str):
 def add_file(db_path: str, file_paths: list[str]):
     """Add one or more files to the semantic search database."""
     print(f"Adding {len(file_paths)} file(s) to database: {db_path}")

     db_file = Path(db_path)
     successful_files = []
     failed_files = []

     for i, file_path in enumerate(file_paths, 1):
         print(f"\n[{i}/{len(file_paths)}] Processing: {file_path}")
         try:
@@ -133,7 +156,7 @@ def add_file(db_path: str, file_paths: list[str]):
         except Exception as e:
             failed_files.append((file_path, str(e)))
             print(f"❌ Failed to add {file_path}: {e}")

     # Summary
     print(f"\n{'='*60}")
     print("SUMMARY:")
@@ -141,44 +164,43 @@ def add_file(db_path: str, file_paths: list[str]):
     if successful_files:
         for file_path in successful_files:
             print(f" - {Path(file_path).name}")

     if failed_files:
         print(f"❌ Failed to add: {len(failed_files)} files")
         for file_path, error in failed_files:
             print(f" - {Path(file_path).name}: {error}")

     print(f"{'='*60}")


 def query(db_path: str, query_text: str):
     """Query the semantic search database."""
     print(f"Querying: '{query_text}' in database: {db_path}")

     try:
         results = db.query(Path(db_path), query_text)

         if not results:
             print("No results found.")
             return

         print(f"\nFound {len(results)} results:")
         print("=" * 60)

         for i, res in enumerate(results, 1):
             print(f"\n{i}. Distance: {res.distance:.4f}")
             print(f" Document: {res.document_name}")
             print(f" Page: {res.record.page}, Chunk: {res.record.chunk}")
             # Replace all whitespace characters with regular spaces for cleaner display
-            clean_text = ' '.join(res.record.text[:200].split())
+            clean_text = " ".join(res.record.text[:200].split())
             print(f" Text preview: {clean_text}...")
             if i < len(results):
                 print("-" * 40)

     except Exception as e:
         print(f"Error querying database: {e}")


 def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
     """Start a web server for the semantic search tool."""
     try:
@@ -190,63 +212,67 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
     # Set template_folder to 'templates' directory
     app = Flask(__name__, template_folder="templates")
     db_file = Path(db_path)

     # Check if database exists
     if not db_file.exists():
         print(f"❌ Database file not found: {db_file}")
         print(" Create a database first using: python main.py create")
         sys.exit(1)

-    @app.route('/')
+    @app.route("/")
     def index():
         return render_template("index.html", results=None)

-    @app.route('/file/<int:document_index>')
+    @app.route("/file/<int:document_index>")
     def serve_file(document_index):
         """Serve PDF files directly."""
         try:
             file_path = db.get_document_path(db_file, document_index)
             if not file_path.exists():
-                return jsonify({'error': 'File not found'}), 404
+                return jsonify({"error": "File not found"}), 404
             return send_file(file_path, as_attachment=False)
         except Exception as e:
-            return jsonify({'error': str(e)}), 500
+            return jsonify({"error": str(e)}), 500

-    @app.route('/api/search', methods=['POST'])
+    @app.route("/api/search", methods=["POST"])
     def search():
         try:
             data = request.get_json()
-            if not data or 'query' not in data:
-                return jsonify({'error': 'Missing query parameter'}), 400
+            if not data or "query" not in data:
+                return jsonify({"error": "Missing query parameter"}), 400

-            query_text = data['query'].strip()
+            query_text = data["query"].strip()
             if not query_text:
-                return jsonify({'error': 'Query cannot be empty'}), 400
+                return jsonify({"error": "Query cannot be empty"}), 400

             # Perform the search
             results = db.query(db_file, query_text)

             # Format results for JSON response
             formatted_results = []
             for res in results:
-                formatted_results.append({
-                    'distance': float(res.distance),
-                    'document_name': res.document_name,
-                    'document_index': res.record.document_index,
-                    'page': res.record.page,
-                    'chunk': res.record.chunk,
-                    'text': ' '.join(res.record.text[:300].split()) # Clean and truncate text
-                })
-            return jsonify({'results': formatted_results})
+                formatted_results.append(
+                    {
+                        "distance": float(res.distance),
+                        "document_name": res.document_name,
+                        "document_index": res.record.document_index,
+                        "page": res.record.page,
+                        "chunk": res.record.chunk,
+                        "text": " ".join(
+                            res.record.text[:300].split()
+                        ), # Clean and truncate text
+                    }
+                )
+            return jsonify({"results": formatted_results})

         except Exception as e:
-            return jsonify({'error': str(e)}), 500
+            return jsonify({"error": str(e)}), 500

     print("🚀 Starting web server...")
     print(f" Database: {db_file}")
     print(f" URL: http://{host}:{port}")
     print(" Press Ctrl+C to stop")

     try:
         app.run(host=host, port=port, debug=False)
     except KeyboardInterrupt:
@@ -258,49 +284,71 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
 def main():
     parser = argparse.ArgumentParser(
         description="Semantic Search Tool",
-        formatter_class=argparse.RawDescriptionHelpFormatter
+        formatter_class=argparse.RawDescriptionHelpFormatter,
     )

     # Create subparsers for different commands
-    subparsers = parser.add_subparsers(dest='command', help='Available commands')
+    subparsers = parser.add_subparsers(dest="command", help="Available commands")

     # Create command
-    create_parser = subparsers.add_parser('create', aliases=['c'], help='Create a new empty database')
-    create_parser.add_argument('db_path', nargs='?', default=str(DEFAULT_DB_PATH),
-                               help=f'Path to database file (default: {DEFAULT_DB_PATH})')
+    create_parser = subparsers.add_parser(
+        "create", aliases=["c"], help="Create a new empty database"
+    )
+    create_parser.add_argument(
+        "db_path",
+        nargs="?",
+        default=str(DEFAULT_DB_PATH),
+        help=f"Path to database file (default: {DEFAULT_DB_PATH})",
+    )

     # Add file command
-    add_parser = subparsers.add_parser('add-file', aliases=['a'], help='Add one or more files to the search database')
-    add_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
-    add_parser.add_argument('file_paths', nargs='+', help='Path(s) to the PDF file(s) to add')
+    add_parser = subparsers.add_parser(
+        "add-file", aliases=["a"], help="Add one or more files to the search database"
+    )
+    add_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
+    add_parser.add_argument(
+        "file_paths", nargs="+", help="Path(s) to the PDF file(s) to add"
+    )

     # Query command
-    query_parser = subparsers.add_parser('query', aliases=['q'], help='Query the search database')
-    query_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
-    query_parser.add_argument('query_text', help='Text to search for')
+    query_parser = subparsers.add_parser(
+        "query", aliases=["q"], help="Query the search database"
+    )
+    query_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
+    query_parser.add_argument("query_text", help="Text to search for")

     # Host command (web server)
-    host_parser = subparsers.add_parser('host', aliases=['h'], help='Start a web server for semantic search')
-    host_parser.add_argument('db', help='Path to the database file (e.g., db.pkl)')
-    host_parser.add_argument('--host', default='127.0.0.1', help='Host address to bind to (default: 127.0.0.1)')
-    host_parser.add_argument('--port', type=int, default=5000, help='Port to listen on (default: 5000)')
+    host_parser = subparsers.add_parser(
+        "host", aliases=["h"], help="Start a web server for semantic search"
+    )
+    host_parser.add_argument("db", help="Path to the database file (e.g., db.pkl)")
+    host_parser.add_argument(
+        "--host",
+        default="127.0.0.1",
+        help="Host address to bind to (default: 127.0.0.1)",
+    )
+    host_parser.add_argument(
+        "--port", type=int, default=5000, help="Port to listen on (default: 5000)"
+    )

     # Test command
-    subparsers.add_parser('test', aliases=['t'], help='Test database save/load functionality')
+    subparsers.add_parser(
+        "test", aliases=["t"], help="Test database save/load functionality"
+    )

     # Parse arguments
     args = parser.parse_args()

     # Handle commands
-    if args.command in ['create', 'c']:
+    if args.command in ["create", "c"]:
         create_database(args.db_path)
-    elif args.command in ['add-file', 'a']:
+    elif args.command in ["add-file", "a"]:
         add_file(args.db, args.file_paths)
-    elif args.command in ['query', 'q']:
+    elif args.command in ["query", "q"]:
         query(args.db, args.query_text)
-    elif args.command in ['host', 'h']:
+    elif args.command in ["host", "h"]:
         start_web_server(args.db, args.host, args.port)
-    elif args.command in ['test', 't']:
+    elif args.command in ["test", "t"]:
         test_database()
     else:
         parser.print_help()
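
Note: the /api/search endpoint above expects a JSON body with a "query" field and returns a "results" list whose entries carry distance, document_name, document_index, page, chunk, and text. A small client sketch using only the standard library, assuming the server was started with the defaults from this file (python main.py host db.pkl); the query string "lorem ipsum" is an arbitrary example:

import json
from urllib.request import Request, urlopen

# Assumes the Flask server from start_web_server() is listening on the
# default host and port (127.0.0.1:5000).
req = Request(
    "http://127.0.0.1:5000/api/search",
    data=json.dumps({"query": "lorem ipsum"}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    for hit in json.load(resp)["results"]:
        print(f"{hit['distance']:.4f} {hit['document_name']} p.{hit['page']}")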