Fixed warnings from pylint
db.py  (33 changed lines)
@@ -1,3 +1,7 @@
+#pylint: disable=missing-class-docstring,invalid-name,broad-exception-caught
+"""
+Database module for semantic document search tool.
+"""
 import pickle
 from pathlib import Path
 from dataclasses import dataclass
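Side note on the disables above: a pylint directive at the top of a module suppresses those messages for the whole file; the same message names can also be suppressed per line or per block. A small illustrative sketch, not part of this commit (risky_call is a hypothetical helper):

def risky_call() -> None:
    # hypothetical helper that may raise anything
    raise OSError("backend unavailable")

try:
    risky_call()
except Exception:  # pylint: disable=broad-exception-caught
    pass  # scoped disable: only this except clause is exempt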
@@ -6,13 +10,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import numpy as np
 import pymupdf
-import ollama # TODO split to another file
+import ollama
 
 #
 # Types
 #
 
-type Vector = np.NDArray # np.NDArray[np.float32] ?
+type Vector = np.NDArray
 type VectorBytes = bytes
 
 
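Note on the Vector alias kept here: the top-level numpy namespace does not export NDArray; the generic array alias lives in numpy.typing, which would also answer the np.NDArray[np.float32] question from the dropped comment. A hedged sketch of that spelling, illustrative only (the commit keeps np.NDArray):

import numpy as np
import numpy.typing as npt

type Vector = npt.NDArray[np.float32]  # alternative, runtime-resolvable alias

def norm(v: Vector) -> float:
    # simple use of the alias in an annotation
    return float(np.linalg.norm(v))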
@@ -90,12 +94,29 @@ def _vectorize_record(record: Record) -> tuple[Record, Vector]:
     return record, _embed(record.text)
 
 
+def test_embedding() -> bool:
+    """
+    Test if embedding functionality is available and working.
+
+    Returns:
+        bool: True if embedding is working, False otherwise
+    """
+    try:
+        _ = _embed("Test.")
+        return True
+    except Exception:
+        return False
+
+
 #
 # High-level (exported) functions
 #
 
 
 def create_dummy() -> Database:
+    """
+    Create a dummy database for testing purposes.
+    """
     db_length: Final[int] = 10
     vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)]
     records = {
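The new test_embedding helper is the public replacement for poking the private _embed from main.py; the later main.py hunk in this commit calls it as shown in this minimal sketch (names as in the diff):

import db

embedding_ok = db.test_embedding()
print(f" Embedding test {'PASSED' if embedding_ok else 'FAILED'}")
if not embedding_ok:
    print(" Did you start ollama docker image?")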
@@ -250,13 +271,11 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
         records: list[Record] = []
         chunk_size = 1024
 
-        for page_num in range(len(doc)):
-            page = doc[page_num]
+        for page_num, page in enumerate(doc):
             text = page.get_text().strip()
             if not text:
                 print(f" Page {page_num + 1}: Skipped (empty)")
                 continue
 
             # Simple chunking - split text into chunks of specified size
             for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
                 chunk = text[i : i + chunk_size]
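The rewritten page loop leans on pymupdf documents being iterable, so enumerate(doc) yields (index, page) pairs directly; the chunking below it is plain fixed-width slicing. A standalone sketch of that slicing step with illustrative values:

text = "x" * 2500      # stand-in for one page of extracted text
chunk_size = 1024

for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
    chunk = text[i : i + chunk_size]
    # chunk_idx counts 0, 1, 2, ...; the final chunk may be shorter
    print(chunk_idx, len(chunk))   # -> 0 1024 / 1 1024 / 2 452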
@@ -267,14 +286,14 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
                 )
         doc.close()
     except Exception as e:
-        raise RuntimeError(f"Error processing PDF {file}: {e}")
+        raise RuntimeError(f"Error processing PDF {file}: {e}") from e
 
     # Process chunks in parallel
     print(f"Processing {len(records)} chunks with {max_workers} workers...")
 
     db.documents.append(file)
 
-    # TODO measure with GIL disabled to check if multithreading actually helps
+    # NOTE this will only help with GIL disabled
     with ThreadPoolExecutor(max_workers=max_workers) as pool:
         futures = [pool.submit(_vectorize_record, r) for r in records]
         for f in as_completed(futures):
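The from e added to the re-raise is standard exception chaining: the original error is kept as __cause__, so both tracebacks are reported and pylint's raise-missing-from warning is satisfied. A minimal standalone sketch of the behaviour (parse_pdf is a hypothetical stand-in):

def parse_pdf(path: str) -> None:
    try:
        raise ValueError("corrupt xref table")   # stand-in for a pymupdf failure
    except Exception as e:
        raise RuntimeError(f"Error processing PDF {path}: {e}") from e

try:
    parse_pdf("report.pdf")
except RuntimeError as err:
    assert isinstance(err.__cause__, ValueError)  # chained cause is preserved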
main.py  (34 changed lines)
@@ -1,9 +1,15 @@
-import argparse
+#pylint: disable=broad-exception-caught
+"""
+Semantic search tool main script.
+Provides command-line interface and web server for creating, adding and querying the database.
+"""
 import sys
-from pathlib import Path
 import tempfile
-import numpy as np
+import argparse
 from typing import Final
+from pathlib import Path
+
+import numpy as np
 
 import db
 
@@ -88,17 +94,10 @@ def test_database():
 
     # Test embedding functionality
     print("\n5. Testing embedding functionality (Ollama API server)...")
-    try:
-        test_embedding = db._embed("This is a test text for embedding.")
-        print(
-            f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}"
-        )
-        ollama_running = True
-    except Exception as e:
-        print(
-            f" Embedding test FAILED: {e}\n Did you start ollama docker image?"
-        )
-        ollama_running = False
+    embedding_ok = db.test_embedding()
+    print(f" Embedding test {'PASSED' if embedding_ok else 'FAILED'}")
+    if not embedding_ok:
+        print(" Did you start ollama docker image?")
 
     # Summary
     all_good = (
@@ -106,7 +105,7 @@ def test_database():
         and records_match
         and vectors_equal
         and records_equal
-        and ollama_running
+        and embedding_ok
     )
     print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}")
 
@@ -204,6 +203,8 @@ def query(db_path: str, query_text: str):
 def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
     """Start a web server for the semantic search tool."""
     try:
+        # here we intentionally import inside the function to avoid Flask dependency for CLI usage
+        # pylint: disable=import-outside-toplevel
         from flask import Flask, request, jsonify, render_template, send_file
     except ImportError:
         print("❌ Flask not found. Please install it first:")
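The two added comments document an existing pattern: importing Flask inside the function keeps it an optional dependency for pure-CLI runs, and the scoped pylint disable silences import-outside-toplevel only there. A minimal sketch of the same pattern with a stdlib stand-in for the optional package:

def start_optional_feature() -> None:
    """Only touches the optional dependency when actually called."""
    try:
        # pylint: disable=import-outside-toplevel
        import tomllib  # stand-in for an optional package such as Flask
    except ImportError:
        print("optional dependency not installed")
        return
    print(tomllib.loads("x = 1"))  # {'x': 1}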
@@ -282,6 +283,9 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
 
 
 def main():
+    """
+    Main function to parse command-line arguments and execute commands.
+    """
     parser = argparse.ArgumentParser(
         description="Semantic Search Tool",
         formatter_class=argparse.RawDescriptionHelpFormatter,