Added base class for PathFinder, fixed typing errors
This commit is contained in:
parent
29f08036c2
commit
2932d95c66
@ -8,7 +8,8 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
from typing import Protocol, Optional
|
from typing import Optional
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
#
|
#
|
||||||
# Type and interfaces definition
|
# Type and interfaces definition
|
||||||
@ -16,15 +17,13 @@ from typing import Protocol, Optional
|
|||||||
|
|
||||||
type Point2D = tuple[int, int] # tuple(x, y)
|
type Point2D = tuple[int, int] # tuple(x, y)
|
||||||
type Path = list[Point2D]
|
type Path = list[Point2D]
|
||||||
type ElapsedTime_ns = int # nanoseconds
|
|
||||||
type VisitedNodeCount = int
|
|
||||||
|
|
||||||
class Map:
|
class Map:
|
||||||
"""
|
"""
|
||||||
2D map consisting of cells with given cost
|
2D map consisting of cells with given cost
|
||||||
"""
|
"""
|
||||||
# array not defined as private, as plotting utilities work with it directly
|
# array not defined as private, as plotting utilities work with it directly
|
||||||
array: np.array
|
array: np.ndarray
|
||||||
_visited_nodes: int
|
_visited_nodes: int
|
||||||
|
|
||||||
def __init__(self, width: int, height: int) -> None:
|
def __init__(self, width: int, height: int) -> None:
|
||||||
@ -60,6 +59,14 @@ class Map:
|
|||||||
points.append(p)
|
points.append(p)
|
||||||
return points
|
return points
|
||||||
|
|
||||||
|
def GetPointCost(self, point: Point2D) -> float:
|
||||||
|
x, y = point
|
||||||
|
row, col = y, x
|
||||||
|
return self.array[(row, col)]
|
||||||
|
|
||||||
|
def GetPathCost(self, path: Path) -> float:
|
||||||
|
return sum([self.GetPointCost(p) for p in path])
|
||||||
|
|
||||||
def ResetVisitedCount(self) -> None:
|
def ResetVisitedCount(self) -> None:
|
||||||
self._visited_nodes = 0
|
self._visited_nodes = 0
|
||||||
|
|
||||||
@ -68,32 +75,12 @@ class Map:
|
|||||||
|
|
||||||
def Visit(self, point: Point2D) -> float:
|
def Visit(self, point: Point2D) -> float:
|
||||||
"""
|
"""
|
||||||
Visit the node, returning its cost
|
Visit the node and return its cost
|
||||||
"""
|
"""
|
||||||
if not self.IsPointValid(point):
|
if not self.IsPointValid(point):
|
||||||
raise ValueError("Point out of bounds")
|
raise ValueError("Point out of bounds")
|
||||||
self._visited_nodes += 1
|
self._visited_nodes += 1
|
||||||
x, y = point
|
return self.GetPointCost(point)
|
||||||
row, col = y, x
|
|
||||||
return self.array[(row, col)]
|
|
||||||
|
|
||||||
|
|
||||||
class PathFinder(Protocol):
|
|
||||||
def SetMap(m: Map) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
def CalculatePath(start: Point2D, end: Point2D) -> Path:
|
|
||||||
"""
|
|
||||||
Calculate path on a given map.
|
|
||||||
Note: map must be set first using SetMap (or using constructor)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def GetStats() -> (ElapsedTime_ns, VisitedNodeCount):
|
|
||||||
"""
|
|
||||||
Return performance stats for the last calculation:
|
|
||||||
- elapsed time in nanoseconds,
|
|
||||||
- number of visited nodes during search
|
|
||||||
"""
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Drawing utilities
|
# Drawing utilities
|
||||||
@ -129,50 +116,68 @@ class Visualizer:
|
|||||||
self._axes.plot(xs[-1], ys[-1], 'o', color='magenta', markersize=8) # end point
|
self._axes.plot(xs[-1], ys[-1], 'o', color='magenta', markersize=8) # end point
|
||||||
|
|
||||||
#
|
#
|
||||||
# Method: depth-first search
|
# Pathfinding implementations
|
||||||
#
|
#
|
||||||
|
|
||||||
class DFS:
|
class PathFinderBase(ABC):
|
||||||
|
name: str
|
||||||
name = "Depth First Search"
|
|
||||||
_map: Optional[Map]
|
_map: Optional[Map]
|
||||||
_elapsed_time_ns: int
|
_elapsed_time_ns: int
|
||||||
_visited_node_count: int
|
_visited_node_count: int
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._map = None
|
self._map = None
|
||||||
|
self._elapsed_time_ns = 0
|
||||||
|
self._visited_node_count = 0
|
||||||
|
|
||||||
|
|
||||||
def SetMap(self, m: Map) -> None:
|
def SetMap(self, m: Map) -> None:
|
||||||
self._map = m
|
self._map = m
|
||||||
|
|
||||||
def CalculatePath(self, start: Point2D, end: Point2D) -> Path:
|
def CalculatePath(self, start: Point2D, end: Point2D) -> Optional[Path]:
|
||||||
|
"""
|
||||||
|
Calculate path on a given map.
|
||||||
|
Note: map must be set first using SetMap
|
||||||
|
"""
|
||||||
assert self._map is not None, "SetMap must be called first"
|
assert self._map is not None, "SetMap must be called first"
|
||||||
self._map.ResetVisitedCount()
|
self._map.ResetVisitedCount()
|
||||||
start_time = time.perf_counter_ns()
|
start_time = time.perf_counter_ns()
|
||||||
|
|
||||||
res = self._CalculatePath(start, end)
|
res = self._CalculatePath(start, end)
|
||||||
print(f"{res=}")
|
|
||||||
|
|
||||||
stop_time = time.perf_counter_ns()
|
stop_time = time.perf_counter_ns()
|
||||||
self._elapsed_time_ns = stop_time - start_time
|
self._elapsed_time_ns = stop_time - start_time
|
||||||
self._visited_node_count = self._map.GetVisitedCount()
|
self._visited_node_count = self._map.GetVisitedCount()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def GetStats(self) -> (ElapsedTime_ns, VisitedNodeCount):
|
@abstractmethod
|
||||||
|
def _CalculatePath(self, start: Point2D, end: Point2D) -> Optional[Path]:
|
||||||
|
"""
|
||||||
|
This method must be implemented by the derived classes
|
||||||
|
"""
|
||||||
|
|
||||||
|
def GetStats(self) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Return performance stats for the last calculation:
|
||||||
|
- elapsed time in nanoseconds,
|
||||||
|
- number of visited nodes during search
|
||||||
|
"""
|
||||||
return self._elapsed_time_ns, self._visited_node_count
|
return self._elapsed_time_ns, self._visited_node_count
|
||||||
|
|
||||||
|
|
||||||
|
class DFS(PathFinderBase):
|
||||||
|
|
||||||
|
name = "Depth First Search"
|
||||||
|
|
||||||
def _CalculatePath(self,
|
def _CalculatePath(self,
|
||||||
point: Point2D,
|
point: Point2D,
|
||||||
end_point: Point2D,
|
end_point: Point2D,
|
||||||
path: Optional[list[Point2D]] = None,
|
path: Optional[list[Point2D]] = None,
|
||||||
visited: Optional[set[Point2D]] = None) -> Optional[Path]:
|
visited: Optional[set[Point2D]] = None) -> Optional[Path]:
|
||||||
"""
|
|
||||||
Find (any) path, not guaranteed to be optimal (and it most probably won't be)
|
|
||||||
"""
|
|
||||||
if visited is None:
|
if visited is None:
|
||||||
visited = set()
|
visited = set()
|
||||||
if path is None:
|
if path is None:
|
||||||
path = list()
|
path = list()
|
||||||
|
if self._map is None:
|
||||||
|
return None # to make mypy happy
|
||||||
# We don't need to know cost in this case, but we still want to track
|
# We don't need to know cost in this case, but we still want to track
|
||||||
# how many nodes we've visited
|
# how many nodes we've visited
|
||||||
_ = self._map.Visit(point)
|
_ = self._map.Visit(point)
|
||||||
@ -198,12 +203,12 @@ def main():
|
|||||||
# Define the map and start/stop points
|
# Define the map and start/stop points
|
||||||
m = Map(15,10)
|
m = Map(15,10)
|
||||||
m.Randomize()
|
m.Randomize()
|
||||||
starting_point: Point2D = (0,9)
|
starting_point: Point2D = (14,1)
|
||||||
end_point: Point2D = (14,1)
|
end_point: Point2D = (0,9)
|
||||||
|
|
||||||
path_finder_classes: list[PathFinder] = {
|
path_finder_classes: list[type[PathFinderBase]] = [
|
||||||
DFS,
|
DFS,
|
||||||
}
|
]
|
||||||
|
|
||||||
v = Visualizer()
|
v = Visualizer()
|
||||||
v.DrawMap(m)
|
v.DrawMap(m)
|
||||||
@ -213,8 +218,12 @@ def main():
|
|||||||
path_finder.SetMap(m)
|
path_finder.SetMap(m)
|
||||||
path = path_finder.CalculatePath(starting_point, end_point)
|
path = path_finder.CalculatePath(starting_point, end_point)
|
||||||
elapsed_time, visited_nodes = path_finder.GetStats()
|
elapsed_time, visited_nodes = path_finder.GetStats()
|
||||||
print(f"{path_finder.name:22}: took {elapsed_time/1e6} ms, visited {visited_nodes} nodes")
|
if path is not None:
|
||||||
v.DrawPath(path)
|
cost = m.GetPathCost(path)
|
||||||
|
print(f"{path_finder.name:22}: took {elapsed_time/1e6:.3f} ms, visited {visited_nodes} nodes, cost {cost:.2f}")
|
||||||
|
v.DrawPath(path)
|
||||||
|
else:
|
||||||
|
print(f"{path_finder.name}: No path found")
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user