Implemented Map.GetNeighbours

This commit is contained in:
Jan Mrna 2025-09-19 14:29:23 +02:00
parent a9be2e5bbc
commit 744ccaf478

View File

@ -7,6 +7,7 @@
import matplotlib.pyplot as plt
import numpy as np
import time
from typing import Protocol, Optional
#
@ -15,7 +16,7 @@ from typing import Protocol, Optional
type Point2D = tuple[int, int] # tuple(x, y)
type Path = list[Point2D]
type ElapsedTime_ns = float # nanoseconds
type ElapsedTime_ns = int # nanoseconds
type VisitedNodeCount = int
class Map:
@ -44,8 +45,17 @@ class Map:
y_in_bounds = (0 <= y < y_max)
return x_in_bounds and y_in_bounds
def GetNeighbours(self) -> list[Point2D]:
...
def GetNeighbours(self, center_point: Point2D) -> list[Point2D]:
points: list[Point2D] = []
x_center, y_center = center_point
for x in range(-1,2):
for y in range(-1,2):
if x == 0 and y == 0:
continue
p = (x + x_center, y + y_center)
if self.IsPointValid(p):
points.append(p)
return points
def ResetVisitedCount(self) -> None:
self._visited_nodes = 0
@ -53,7 +63,10 @@ class Map:
def GetVisitedCount(self) -> int:
return self._visited_nodes
def GetCost(self, point: Point2D) -> float:
def Visit(self, point: Point2D) -> float:
"""
Visit the node, returning its cost
"""
if not self.IsPointValid(point):
raise ValueError("Point out of bounds")
self._visited_nodes += 1
@ -96,10 +109,11 @@ class Visualizer:
def DrawMap(self, m: Map):
M, N = m.array.shape
_, ax = plt.subplots()
ax.imshow(m.array, cmap='terrain', origin='lower', interpolation='none')
ax.imshow(m.array, cmap='gist_earth', origin='lower', interpolation='none')
self._axes = ax
def DrawPath(self, path: Path, label: str = "Path"):
"""
Draw path on a map. Note that DrawMap has to be called first
"""
@ -119,6 +133,8 @@ class DFS:
name = "Depth First Search"
_map: Optional[Map]
_elapsed_time_ns: int
_visited_node_count: int
def __init__(self) -> None:
self._map = None
@ -128,11 +144,21 @@ class DFS:
def CalculatePath(self, start: Point2D, end: Point2D) -> Path:
assert self._map is not None, "SetMap must be called first"
self._map.ResetVisitedCount()
start_time = time.perf_counter_ns()
stop_time = time.perf_counter_ns()
self._elapsed_time_ns = stop_time - start_time
self._visited_node_count = self._map.GetVisitedCount()
return [(0,0), (5,5), (6,6), (1,9)]
def GetStats(self) -> (ElapsedTime_ns, VisitedNodeCount):
return 150.0, 42
return self._elapsed_time_ns, self._visited_node_count
def _visit(point: Point2D)
class BFS:
@ -170,17 +196,18 @@ def main():
v = Visualizer()
v.DrawMap(m)
for pt in path_finder_classes:
path_finder = pt()
for pfc in path_finder_classes:
path_finder = pfc()
path_finder.SetMap(m)
path = path_finder.CalculatePath(starting_point, end_point)
elapsed_time, visited_nodes = path_finder.GetStats()
print(f"{path_finder.name:22}: took {elapsed_time} ns, visited {visited_nodes} nodes")
v.DrawPath(path)
#
p = (9,1)
print(f"{m.IsPointValid(p)=}")
print(f"{m.GetCost(p)=}")
p = (1,9)
# print(f"{m.IsPointValid(p)=}")
# print(f"{m.GetCost(p)=}")
print(f"{m.GetNeighbours(p)}")
plt.show()