From 2932d95c661f340af03f78b016ea2dd10727d296 Mon Sep 17 00:00:00 2001 From: Jan Mrna Date: Sat, 20 Sep 2025 12:16:42 +0200 Subject: [PATCH] Added base class for PathFinder, fixed typing errors --- python/pathfinding_demo.py | 97 +++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 44 deletions(-) diff --git a/python/pathfinding_demo.py b/python/pathfinding_demo.py index 0271ef5..58f035f 100644 --- a/python/pathfinding_demo.py +++ b/python/pathfinding_demo.py @@ -8,7 +8,8 @@ import matplotlib.pyplot as plt import numpy as np import time -from typing import Protocol, Optional +from typing import Optional +from abc import ABC, abstractmethod # # Type and interfaces definition @@ -16,15 +17,13 @@ from typing import Protocol, Optional type Point2D = tuple[int, int] # tuple(x, y) type Path = list[Point2D] -type ElapsedTime_ns = int # nanoseconds -type VisitedNodeCount = int class Map: """ 2D map consisting of cells with given cost """ # array not defined as private, as plotting utilities work with it directly - array: np.array + array: np.ndarray _visited_nodes: int def __init__(self, width: int, height: int) -> None: @@ -60,6 +59,14 @@ class Map: points.append(p) return points + def GetPointCost(self, point: Point2D) -> float: + x, y = point + row, col = y, x + return self.array[(row, col)] + + def GetPathCost(self, path: Path) -> float: + return sum([self.GetPointCost(p) for p in path]) + def ResetVisitedCount(self) -> None: self._visited_nodes = 0 @@ -68,32 +75,12 @@ class Map: def Visit(self, point: Point2D) -> float: """ - Visit the node, returning its cost + Visit the node and return its cost """ if not self.IsPointValid(point): raise ValueError("Point out of bounds") self._visited_nodes += 1 - x, y = point - row, col = y, x - return self.array[(row, col)] - - -class PathFinder(Protocol): - def SetMap(m: Map) -> None: - ... - - def CalculatePath(start: Point2D, end: Point2D) -> Path: - """ - Calculate path on a given map. - Note: map must be set first using SetMap (or using constructor) - """ - - def GetStats() -> (ElapsedTime_ns, VisitedNodeCount): - """ - Return performance stats for the last calculation: - - elapsed time in nanoseconds, - - number of visited nodes during search - """ + return self.GetPointCost(point) # # Drawing utilities @@ -129,50 +116,68 @@ class Visualizer: self._axes.plot(xs[-1], ys[-1], 'o', color='magenta', markersize=8) # end point # -# Method: depth-first search +# Pathfinding implementations # -class DFS: - - name = "Depth First Search" +class PathFinderBase(ABC): + name: str _map: Optional[Map] _elapsed_time_ns: int _visited_node_count: int def __init__(self) -> None: self._map = None + self._elapsed_time_ns = 0 + self._visited_node_count = 0 + def SetMap(self, m: Map) -> None: self._map = m - def CalculatePath(self, start: Point2D, end: Point2D) -> Path: + def CalculatePath(self, start: Point2D, end: Point2D) -> Optional[Path]: + """ + Calculate path on a given map. + Note: map must be set first using SetMap + """ assert self._map is not None, "SetMap must be called first" self._map.ResetVisitedCount() start_time = time.perf_counter_ns() - res = self._CalculatePath(start, end) - print(f"{res=}") - stop_time = time.perf_counter_ns() self._elapsed_time_ns = stop_time - start_time self._visited_node_count = self._map.GetVisitedCount() return res - def GetStats(self) -> (ElapsedTime_ns, VisitedNodeCount): + @abstractmethod + def _CalculatePath(self, start: Point2D, end: Point2D) -> Optional[Path]: + """ + This method must be implemented by the derived classes + """ + + def GetStats(self) -> tuple[int, int]: + """ + Return performance stats for the last calculation: + - elapsed time in nanoseconds, + - number of visited nodes during search + """ return self._elapsed_time_ns, self._visited_node_count + +class DFS(PathFinderBase): + + name = "Depth First Search" + def _CalculatePath(self, point: Point2D, end_point: Point2D, path: Optional[list[Point2D]] = None, visited: Optional[set[Point2D]] = None) -> Optional[Path]: - """ - Find (any) path, not guaranteed to be optimal (and it most probably won't be) - """ if visited is None: visited = set() if path is None: path = list() + if self._map is None: + return None # to make mypy happy # We don't need to know cost in this case, but we still want to track # how many nodes we've visited _ = self._map.Visit(point) @@ -198,12 +203,12 @@ def main(): # Define the map and start/stop points m = Map(15,10) m.Randomize() - starting_point: Point2D = (0,9) - end_point: Point2D = (14,1) + starting_point: Point2D = (14,1) + end_point: Point2D = (0,9) - path_finder_classes: list[PathFinder] = { + path_finder_classes: list[type[PathFinderBase]] = [ DFS, - } + ] v = Visualizer() v.DrawMap(m) @@ -213,8 +218,12 @@ def main(): path_finder.SetMap(m) path = path_finder.CalculatePath(starting_point, end_point) elapsed_time, visited_nodes = path_finder.GetStats() - print(f"{path_finder.name:22}: took {elapsed_time/1e6} ms, visited {visited_nodes} nodes") - v.DrawPath(path) + if path is not None: + cost = m.GetPathCost(path) + print(f"{path_finder.name:22}: took {elapsed_time/1e6:.3f} ms, visited {visited_nodes} nodes, cost {cost:.2f}") + v.DrawPath(path) + else: + print(f"{path_finder.name}: No path found") plt.show()