Fixed warnings from pylint

This commit is contained in:
Jan Mrna
2025-11-06 10:58:18 +01:00
parent e352780a3d
commit 06dc4c5e1f
2 changed files with 45 additions and 22 deletions

33
db.py
View File

@@ -1,3 +1,7 @@
#pylint: disable=missing-class-docstring,invalid-name,broad-exception-caught
"""
Database module for semantic document search tool.
"""
import pickle
from pathlib import Path
from dataclasses import dataclass
@@ -6,13 +10,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import pymupdf
import ollama # TODO split to another file
import ollama
#
# Types
#
# Type aliases used by the database module's signatures.
# NOTE(review): `NDArray` is documented under `numpy.typing`, not the top-level
# `numpy` namespace — confirm that `np.NDArray` actually resolves. A `type`
# statement evaluates its right-hand side lazily, so a bad name here would not
# fail at import time.
type Vector = np.NDArray # np.NDArray[np.float32] ?
type Vector = np.NDArray
type VectorBytes = bytes
@@ -90,12 +94,29 @@ def _vectorize_record(record: Record) -> tuple[Record, Vector]:
return record, _embed(record.text)
def test_embedding() -> bool:
    """
    Probe the embedding backend by embedding a short sample text.

    Returns:
        bool: True when the `_embed` call completes without raising,
        False when it raises any `Exception`.
    """
    try:
        # Only success or failure of the call matters; the vector is discarded.
        _embed("Test.")
    except Exception:
        return False
    return True
#
# High-level (exported) functions
#
def create_dummy() -> Database:
"""
Create a dummy database for testing purposes.
"""
db_length: Final[int] = 10
vectors = [np.array([i, 2 * i, 3 * i, 4 * i]) for i in range(db_length)]
records = {
@@ -250,13 +271,11 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
records: list[Record] = []
chunk_size = 1024
for page_num in range(len(doc)):
page = doc[page_num]
for page_num, page in enumerate(doc):
text = page.get_text().strip()
if not text:
print(f" Page {page_num + 1}: Skipped (empty)")
continue
# Simple chunking - split text into chunks of specified size
for chunk_idx, i in enumerate(range(0, len(text), chunk_size)):
chunk = text[i : i + chunk_size]
@@ -267,14 +286,14 @@ def add_document(db: Database | Path, file: Path, max_workers: int = 4) -> None:
)
doc.close()
except Exception as e:
raise RuntimeError(f"Error processing PDF {file}: {e}")
raise RuntimeError(f"Error processing PDF {file}: {e}") from e
# Process chunks in parallel
print(f"Processing {len(records)} chunks with {max_workers} workers...")
db.documents.append(file)
# TODO measure with GIL disabled to check if multithreading actually helps
# NOTE this will only help with GIL disabled
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [pool.submit(_vectorize_record, r) for r in records]
for f in as_completed(futures):

34
main.py
View File

@@ -1,9 +1,15 @@
import argparse
#pylint: disable=broad-exception-caught
"""
Semantic search tool main script.
Provides command-line interface and web server for creating, adding and querying the database.
"""
import sys
from pathlib import Path
import tempfile
import numpy as np
import argparse
from typing import Final
from pathlib import Path
import numpy as np
import db
@@ -88,17 +94,10 @@ def test_database():
# Test embedding functionality
print("\n5. Testing embedding functionality (Ollama API server)...")
try:
test_embedding = db._embed("This is a test text for embedding.")
print(
f" Embedding test PASSED: Generated vector of shape {test_embedding.shape}"
)
ollama_running = True
except Exception as e:
print(
f" Embedding test FAILED: {e}\n Did you start ollama docker image?"
)
ollama_running = False
embedding_ok = db.test_embedding()
print(f" Embedding test {'PASSED' if embedding_ok else 'FAILED'}")
if not embedding_ok:
print(" Did you start ollama docker image?")
# Summary
all_good = (
@@ -106,7 +105,7 @@ def test_database():
and records_match
and vectors_equal
and records_equal
and ollama_running
and embedding_ok
)
print(f"\n✅ Test {'PASSED' if all_good else 'FAILED'}")
@@ -204,6 +203,8 @@ def query(db_path: str, query_text: str):
def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
"""Start a web server for the semantic search tool."""
try:
# here we intentionally import inside the function to avoid Flask dependency for CLI usage
# pylint: disable=import-outside-toplevel
from flask import Flask, request, jsonify, render_template, send_file
except ImportError:
print("❌ Flask not found. Please install it first:")
@@ -282,6 +283,9 @@ def start_web_server(db_path: str, host: str = "127.0.0.1", port: int = 5000):
def main():
"""
Main function to parse command-line arguments and execute commands.
"""
parser = argparse.ArgumentParser(
description="Semantic Search Tool",
formatter_class=argparse.RawDescriptionHelpFormatter,