Added base class for PathFinder, fixed typing errors

This commit is contained in:
Jan Mrna 2025-09-20 12:16:42 +02:00
parent 29f08036c2
commit 2932d95c66

View File

@ -8,7 +8,8 @@
import matplotlib.pyplot as plt
import numpy as np
import time
from typing import Protocol, Optional
from typing import Optional
from abc import ABC, abstractmethod
#
# Type and interfaces definition
@ -16,15 +17,13 @@ from typing import Protocol, Optional
type Point2D = tuple[int, int] # tuple(x, y)
type Path = list[Point2D]
type ElapsedTime_ns = int # nanoseconds
type VisitedNodeCount = int
class Map:
"""
2D map consisting of cells with given cost
"""
# array not defined as private, as plotting utilities work with it directly
array: np.array
array: np.ndarray
_visited_nodes: int
def __init__(self, width: int, height: int) -> None:
@ -60,6 +59,14 @@ class Map:
points.append(p)
return points
def GetPointCost(self, point: Point2D) -> float:
x, y = point
row, col = y, x
return self.array[(row, col)]
def GetPathCost(self, path: Path) -> float:
return sum([self.GetPointCost(p) for p in path])
def ResetVisitedCount(self) -> None:
self._visited_nodes = 0
@ -68,32 +75,12 @@ class Map:
def Visit(self, point: Point2D) -> float:
"""
Visit the node, returning its cost
Visit the node and return its cost
"""
if not self.IsPointValid(point):
raise ValueError("Point out of bounds")
self._visited_nodes += 1
x, y = point
row, col = y, x
return self.array[(row, col)]
class PathFinder(Protocol):
def SetMap(m: Map) -> None:
...
def CalculatePath(start: Point2D, end: Point2D) -> Path:
"""
Calculate path on a given map.
Note: map must be set first using SetMap (or using constructor)
"""
def GetStats() -> (ElapsedTime_ns, VisitedNodeCount):
"""
Return performance stats for the last calculation:
- elapsed time in nanoseconds,
- number of visited nodes during search
"""
return self.GetPointCost(point)
#
# Drawing utilities
@ -129,50 +116,68 @@ class Visualizer:
self._axes.plot(xs[-1], ys[-1], 'o', color='magenta', markersize=8) # end point
#
# Method: depth-first search
# Pathfinding implementations
#
class DFS:
name = "Depth First Search"
class PathFinderBase(ABC):
name: str
_map: Optional[Map]
_elapsed_time_ns: int
_visited_node_count: int
def __init__(self) -> None:
self._map = None
self._elapsed_time_ns = 0
self._visited_node_count = 0
def SetMap(self, m: Map) -> None:
self._map = m
def CalculatePath(self, start: Point2D, end: Point2D) -> Path:
def CalculatePath(self, start: Point2D, end: Point2D) -> Optional[Path]:
"""
Calculate path on a given map.
Note: map must be set first using SetMap
"""
assert self._map is not None, "SetMap must be called first"
self._map.ResetVisitedCount()
start_time = time.perf_counter_ns()
res = self._CalculatePath(start, end)
print(f"{res=}")
stop_time = time.perf_counter_ns()
self._elapsed_time_ns = stop_time - start_time
self._visited_node_count = self._map.GetVisitedCount()
return res
def GetStats(self) -> (ElapsedTime_ns, VisitedNodeCount):
@abstractmethod
def _CalculatePath(self, start: Point2D, end: Point2D) -> Optional[Path]:
"""
This method must be implemented by the derived classes
"""
def GetStats(self) -> tuple[int, int]:
"""
Return performance stats for the last calculation:
- elapsed time in nanoseconds,
- number of visited nodes during search
"""
return self._elapsed_time_ns, self._visited_node_count
class DFS(PathFinderBase):
name = "Depth First Search"
def _CalculatePath(self,
point: Point2D,
end_point: Point2D,
path: Optional[list[Point2D]] = None,
visited: Optional[set[Point2D]] = None) -> Optional[Path]:
"""
Find (any) path, not guaranteed to be optimal (and it most probably won't be)
"""
if visited is None:
visited = set()
if path is None:
path = list()
if self._map is None:
return None # to make mypy happy
# We don't need to know cost in this case, but we still want to track
# how many nodes we've visited
_ = self._map.Visit(point)
@ -198,12 +203,12 @@ def main():
# Define the map and start/stop points
m = Map(15,10)
m.Randomize()
starting_point: Point2D = (0,9)
end_point: Point2D = (14,1)
starting_point: Point2D = (14,1)
end_point: Point2D = (0,9)
path_finder_classes: list[PathFinder] = {
path_finder_classes: list[type[PathFinderBase]] = [
DFS,
}
]
v = Visualizer()
v.DrawMap(m)
@ -213,8 +218,12 @@ def main():
path_finder.SetMap(m)
path = path_finder.CalculatePath(starting_point, end_point)
elapsed_time, visited_nodes = path_finder.GetStats()
print(f"{path_finder.name:22}: took {elapsed_time/1e6} ms, visited {visited_nodes} nodes")
v.DrawPath(path)
if path is not None:
cost = m.GetPathCost(path)
print(f"{path_finder.name:22}: took {elapsed_time/1e6:.3f} ms, visited {visited_nodes} nodes, cost {cost:.2f}")
v.DrawPath(path)
else:
print(f"{path_finder.name}: No path found")
plt.show()