Implemented Map.GetNeighbours
This commit is contained in:
parent
a9be2e5bbc
commit
744ccaf478
@ -7,6 +7,7 @@
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import time
|
||||
from typing import Protocol, Optional
|
||||
|
||||
#
|
||||
@ -15,7 +16,7 @@ from typing import Protocol, Optional
|
||||
|
||||
type Point2D = tuple[int, int] # tuple(x, y)
|
||||
type Path = list[Point2D]
|
||||
type ElapsedTime_ns = float # nanoseconds
|
||||
type ElapsedTime_ns = int # nanoseconds
|
||||
type VisitedNodeCount = int
|
||||
|
||||
class Map:
|
||||
@ -44,8 +45,17 @@ class Map:
|
||||
y_in_bounds = (0 <= y < y_max)
|
||||
return x_in_bounds and y_in_bounds
|
||||
|
||||
def GetNeighbours(self) -> list[Point2D]:
|
||||
...
|
||||
def GetNeighbours(self, center_point: Point2D) -> list[Point2D]:
|
||||
points: list[Point2D] = []
|
||||
x_center, y_center = center_point
|
||||
for x in range(-1,2):
|
||||
for y in range(-1,2):
|
||||
if x == 0 and y == 0:
|
||||
continue
|
||||
p = (x + x_center, y + y_center)
|
||||
if self.IsPointValid(p):
|
||||
points.append(p)
|
||||
return points
|
||||
|
||||
def ResetVisitedCount(self) -> None:
|
||||
self._visited_nodes = 0
|
||||
@ -53,7 +63,10 @@ class Map:
|
||||
def GetVisitedCount(self) -> int:
|
||||
return self._visited_nodes
|
||||
|
||||
def GetCost(self, point: Point2D) -> float:
|
||||
def Visit(self, point: Point2D) -> float:
|
||||
"""
|
||||
Visit the node, returning its cost
|
||||
"""
|
||||
if not self.IsPointValid(point):
|
||||
raise ValueError("Point out of bounds")
|
||||
self._visited_nodes += 1
|
||||
@ -96,10 +109,11 @@ class Visualizer:
|
||||
def DrawMap(self, m: Map):
|
||||
M, N = m.array.shape
|
||||
_, ax = plt.subplots()
|
||||
ax.imshow(m.array, cmap='terrain', origin='lower', interpolation='none')
|
||||
ax.imshow(m.array, cmap='gist_earth', origin='lower', interpolation='none')
|
||||
self._axes = ax
|
||||
|
||||
def DrawPath(self, path: Path, label: str = "Path"):
|
||||
|
||||
"""
|
||||
Draw path on a map. Note that DrawMap has to be called first
|
||||
"""
|
||||
@ -119,6 +133,8 @@ class DFS:
|
||||
|
||||
name = "Depth First Search"
|
||||
_map: Optional[Map]
|
||||
_elapsed_time_ns: int
|
||||
_visited_node_count: int
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._map = None
|
||||
@ -128,11 +144,21 @@ class DFS:
|
||||
|
||||
def CalculatePath(self, start: Point2D, end: Point2D) -> Path:
|
||||
assert self._map is not None, "SetMap must be called first"
|
||||
self._map.ResetVisitedCount()
|
||||
|
||||
start_time = time.perf_counter_ns()
|
||||
|
||||
|
||||
stop_time = time.perf_counter_ns()
|
||||
|
||||
self._elapsed_time_ns = stop_time - start_time
|
||||
self._visited_node_count = self._map.GetVisitedCount()
|
||||
return [(0,0), (5,5), (6,6), (1,9)]
|
||||
|
||||
def GetStats(self) -> (ElapsedTime_ns, VisitedNodeCount):
|
||||
return 150.0, 42
|
||||
return self._elapsed_time_ns, self._visited_node_count
|
||||
|
||||
def _visit(point: Point2D)
|
||||
|
||||
class BFS:
|
||||
|
||||
@ -170,17 +196,18 @@ def main():
|
||||
v = Visualizer()
|
||||
v.DrawMap(m)
|
||||
|
||||
for pt in path_finder_classes:
|
||||
path_finder = pt()
|
||||
for pfc in path_finder_classes:
|
||||
path_finder = pfc()
|
||||
path_finder.SetMap(m)
|
||||
path = path_finder.CalculatePath(starting_point, end_point)
|
||||
elapsed_time, visited_nodes = path_finder.GetStats()
|
||||
print(f"{path_finder.name:22}: took {elapsed_time} ns, visited {visited_nodes} nodes")
|
||||
v.DrawPath(path)
|
||||
#
|
||||
p = (9,1)
|
||||
print(f"{m.IsPointValid(p)=}")
|
||||
print(f"{m.GetCost(p)=}")
|
||||
p = (1,9)
|
||||
# print(f"{m.IsPointValid(p)=}")
|
||||
# print(f"{m.GetCost(p)=}")
|
||||
print(f"{m.GetNeighbours(p)}")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user