Added base class for PathFinder, fixed typing errors

This commit is contained in:
Jan Mrna 2025-09-20 12:16:42 +02:00
parent 29f08036c2
commit 2932d95c66

View File

@ -8,7 +8,8 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import time import time
from typing import Protocol, Optional from typing import Optional
from abc import ABC, abstractmethod
# #
# Type and interfaces definition # Type and interfaces definition
@ -16,15 +17,13 @@ from typing import Protocol, Optional
type Point2D = tuple[int, int] # tuple(x, y) type Point2D = tuple[int, int] # tuple(x, y)
type Path = list[Point2D] type Path = list[Point2D]
type ElapsedTime_ns = int # nanoseconds
type VisitedNodeCount = int
class Map: class Map:
""" """
2D map consisting of cells with given cost 2D map consisting of cells with given cost
""" """
# array not defined as private, as plotting utilities work with it directly # array not defined as private, as plotting utilities work with it directly
array: np.array array: np.ndarray
_visited_nodes: int _visited_nodes: int
def __init__(self, width: int, height: int) -> None: def __init__(self, width: int, height: int) -> None:
@ -60,6 +59,14 @@ class Map:
points.append(p) points.append(p)
return points return points
def GetPointCost(self, point: Point2D) -> float:
x, y = point
row, col = y, x
return self.array[(row, col)]
def GetPathCost(self, path: Path) -> float:
return sum([self.GetPointCost(p) for p in path])
def ResetVisitedCount(self) -> None: def ResetVisitedCount(self) -> None:
self._visited_nodes = 0 self._visited_nodes = 0
@ -68,32 +75,12 @@ class Map:
def Visit(self, point: Point2D) -> float: def Visit(self, point: Point2D) -> float:
""" """
Visit the node, returning its cost Visit the node and return its cost
""" """
if not self.IsPointValid(point): if not self.IsPointValid(point):
raise ValueError("Point out of bounds") raise ValueError("Point out of bounds")
self._visited_nodes += 1 self._visited_nodes += 1
x, y = point return self.GetPointCost(point)
row, col = y, x
return self.array[(row, col)]
class PathFinder(Protocol):
def SetMap(m: Map) -> None:
...
def CalculatePath(start: Point2D, end: Point2D) -> Path:
"""
Calculate path on a given map.
Note: map must be set first using SetMap (or using constructor)
"""
def GetStats() -> (ElapsedTime_ns, VisitedNodeCount):
"""
Return performance stats for the last calculation:
- elapsed time in nanoseconds,
- number of visited nodes during search
"""
# #
# Drawing utilities # Drawing utilities
@ -129,50 +116,68 @@ class Visualizer:
self._axes.plot(xs[-1], ys[-1], 'o', color='magenta', markersize=8) # end point self._axes.plot(xs[-1], ys[-1], 'o', color='magenta', markersize=8) # end point
# #
# Method: depth-first search # Pathfinding implementations
# #
class DFS: class PathFinderBase(ABC):
name: str
name = "Depth First Search"
_map: Optional[Map] _map: Optional[Map]
_elapsed_time_ns: int _elapsed_time_ns: int
_visited_node_count: int _visited_node_count: int
def __init__(self) -> None: def __init__(self) -> None:
self._map = None self._map = None
self._elapsed_time_ns = 0
self._visited_node_count = 0
def SetMap(self, m: Map) -> None: def SetMap(self, m: Map) -> None:
self._map = m self._map = m
def CalculatePath(self, start: Point2D, end: Point2D) -> Path: def CalculatePath(self, start: Point2D, end: Point2D) -> Optional[Path]:
"""
Calculate path on a given map.
Note: map must be set first using SetMap
"""
assert self._map is not None, "SetMap must be called first" assert self._map is not None, "SetMap must be called first"
self._map.ResetVisitedCount() self._map.ResetVisitedCount()
start_time = time.perf_counter_ns() start_time = time.perf_counter_ns()
res = self._CalculatePath(start, end) res = self._CalculatePath(start, end)
print(f"{res=}")
stop_time = time.perf_counter_ns() stop_time = time.perf_counter_ns()
self._elapsed_time_ns = stop_time - start_time self._elapsed_time_ns = stop_time - start_time
self._visited_node_count = self._map.GetVisitedCount() self._visited_node_count = self._map.GetVisitedCount()
return res return res
def GetStats(self) -> (ElapsedTime_ns, VisitedNodeCount): @abstractmethod
def _CalculatePath(self, start: Point2D, end: Point2D) -> Optional[Path]:
"""
This method must be implemented by the derived classes
"""
def GetStats(self) -> tuple[int, int]:
"""
Return performance stats for the last calculation:
- elapsed time in nanoseconds,
- number of visited nodes during search
"""
return self._elapsed_time_ns, self._visited_node_count return self._elapsed_time_ns, self._visited_node_count
class DFS(PathFinderBase):
name = "Depth First Search"
def _CalculatePath(self, def _CalculatePath(self,
point: Point2D, point: Point2D,
end_point: Point2D, end_point: Point2D,
path: Optional[list[Point2D]] = None, path: Optional[list[Point2D]] = None,
visited: Optional[set[Point2D]] = None) -> Optional[Path]: visited: Optional[set[Point2D]] = None) -> Optional[Path]:
"""
Find (any) path, not guaranteed to be optimal (and it most probably won't be)
"""
if visited is None: if visited is None:
visited = set() visited = set()
if path is None: if path is None:
path = list() path = list()
if self._map is None:
return None # to make mypy happy
# We don't need to know cost in this case, but we still want to track # We don't need to know cost in this case, but we still want to track
# how many nodes we've visited # how many nodes we've visited
_ = self._map.Visit(point) _ = self._map.Visit(point)
@ -198,12 +203,12 @@ def main():
# Define the map and start/stop points # Define the map and start/stop points
m = Map(15,10) m = Map(15,10)
m.Randomize() m.Randomize()
starting_point: Point2D = (0,9) starting_point: Point2D = (14,1)
end_point: Point2D = (14,1) end_point: Point2D = (0,9)
path_finder_classes: list[PathFinder] = { path_finder_classes: list[type[PathFinderBase]] = [
DFS, DFS,
} ]
v = Visualizer() v = Visualizer()
v.DrawMap(m) v.DrawMap(m)
@ -213,8 +218,12 @@ def main():
path_finder.SetMap(m) path_finder.SetMap(m)
path = path_finder.CalculatePath(starting_point, end_point) path = path_finder.CalculatePath(starting_point, end_point)
elapsed_time, visited_nodes = path_finder.GetStats() elapsed_time, visited_nodes = path_finder.GetStats()
print(f"{path_finder.name:22}: took {elapsed_time/1e6} ms, visited {visited_nodes} nodes") if path is not None:
v.DrawPath(path) cost = m.GetPathCost(path)
print(f"{path_finder.name:22}: took {elapsed_time/1e6:.3f} ms, visited {visited_nodes} nodes, cost {cost:.2f}")
v.DrawPath(path)
else:
print(f"{path_finder.name}: No path found")
plt.show() plt.show()