Implemented Map.GetNeighbours
This commit is contained in:
parent
a9be2e5bbc
commit
744ccaf478
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import time
|
||||||
from typing import Protocol, Optional
|
from typing import Protocol, Optional
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -15,7 +16,7 @@ from typing import Protocol, Optional
|
|||||||
|
|
||||||
type Point2D = tuple[int, int] # tuple(x, y)
|
type Point2D = tuple[int, int] # tuple(x, y)
|
||||||
type Path = list[Point2D]
|
type Path = list[Point2D]
|
||||||
type ElapsedTime_ns = float # nanoseconds
|
type ElapsedTime_ns = int # nanoseconds
|
||||||
type VisitedNodeCount = int
|
type VisitedNodeCount = int
|
||||||
|
|
||||||
class Map:
|
class Map:
|
||||||
@ -44,8 +45,17 @@ class Map:
|
|||||||
y_in_bounds = (0 <= y < y_max)
|
y_in_bounds = (0 <= y < y_max)
|
||||||
return x_in_bounds and y_in_bounds
|
return x_in_bounds and y_in_bounds
|
||||||
|
|
||||||
def GetNeighbours(self) -> list[Point2D]:
|
def GetNeighbours(self, center_point: Point2D) -> list[Point2D]:
|
||||||
...
|
points: list[Point2D] = []
|
||||||
|
x_center, y_center = center_point
|
||||||
|
for x in range(-1,2):
|
||||||
|
for y in range(-1,2):
|
||||||
|
if x == 0 and y == 0:
|
||||||
|
continue
|
||||||
|
p = (x + x_center, y + y_center)
|
||||||
|
if self.IsPointValid(p):
|
||||||
|
points.append(p)
|
||||||
|
return points
|
||||||
|
|
||||||
def ResetVisitedCount(self) -> None:
|
def ResetVisitedCount(self) -> None:
|
||||||
self._visited_nodes = 0
|
self._visited_nodes = 0
|
||||||
@ -53,7 +63,10 @@ class Map:
|
|||||||
def GetVisitedCount(self) -> int:
|
def GetVisitedCount(self) -> int:
|
||||||
return self._visited_nodes
|
return self._visited_nodes
|
||||||
|
|
||||||
def GetCost(self, point: Point2D) -> float:
|
def Visit(self, point: Point2D) -> float:
|
||||||
|
"""
|
||||||
|
Visit the node, returning its cost
|
||||||
|
"""
|
||||||
if not self.IsPointValid(point):
|
if not self.IsPointValid(point):
|
||||||
raise ValueError("Point out of bounds")
|
raise ValueError("Point out of bounds")
|
||||||
self._visited_nodes += 1
|
self._visited_nodes += 1
|
||||||
@ -96,10 +109,11 @@ class Visualizer:
|
|||||||
def DrawMap(self, m: Map):
|
def DrawMap(self, m: Map):
|
||||||
M, N = m.array.shape
|
M, N = m.array.shape
|
||||||
_, ax = plt.subplots()
|
_, ax = plt.subplots()
|
||||||
ax.imshow(m.array, cmap='terrain', origin='lower', interpolation='none')
|
ax.imshow(m.array, cmap='gist_earth', origin='lower', interpolation='none')
|
||||||
self._axes = ax
|
self._axes = ax
|
||||||
|
|
||||||
def DrawPath(self, path: Path, label: str = "Path"):
|
def DrawPath(self, path: Path, label: str = "Path"):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Draw path on a map. Note that DrawMap has to be called first
|
Draw path on a map. Note that DrawMap has to be called first
|
||||||
"""
|
"""
|
||||||
@ -119,6 +133,8 @@ class DFS:
|
|||||||
|
|
||||||
name = "Depth First Search"
|
name = "Depth First Search"
|
||||||
_map: Optional[Map]
|
_map: Optional[Map]
|
||||||
|
_elapsed_time_ns: int
|
||||||
|
_visited_node_count: int
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._map = None
|
self._map = None
|
||||||
@ -128,11 +144,21 @@ class DFS:
|
|||||||
|
|
||||||
def CalculatePath(self, start: Point2D, end: Point2D) -> Path:
|
def CalculatePath(self, start: Point2D, end: Point2D) -> Path:
|
||||||
assert self._map is not None, "SetMap must be called first"
|
assert self._map is not None, "SetMap must be called first"
|
||||||
|
self._map.ResetVisitedCount()
|
||||||
|
|
||||||
|
start_time = time.perf_counter_ns()
|
||||||
|
|
||||||
|
|
||||||
|
stop_time = time.perf_counter_ns()
|
||||||
|
|
||||||
|
self._elapsed_time_ns = stop_time - start_time
|
||||||
|
self._visited_node_count = self._map.GetVisitedCount()
|
||||||
return [(0,0), (5,5), (6,6), (1,9)]
|
return [(0,0), (5,5), (6,6), (1,9)]
|
||||||
|
|
||||||
def GetStats(self) -> (ElapsedTime_ns, VisitedNodeCount):
|
def GetStats(self) -> (ElapsedTime_ns, VisitedNodeCount):
|
||||||
return 150.0, 42
|
return self._elapsed_time_ns, self._visited_node_count
|
||||||
|
|
||||||
|
def _visit(point: Point2D)
|
||||||
|
|
||||||
class BFS:
|
class BFS:
|
||||||
|
|
||||||
@ -170,17 +196,18 @@ def main():
|
|||||||
v = Visualizer()
|
v = Visualizer()
|
||||||
v.DrawMap(m)
|
v.DrawMap(m)
|
||||||
|
|
||||||
for pt in path_finder_classes:
|
for pfc in path_finder_classes:
|
||||||
path_finder = pt()
|
path_finder = pfc()
|
||||||
path_finder.SetMap(m)
|
path_finder.SetMap(m)
|
||||||
path = path_finder.CalculatePath(starting_point, end_point)
|
path = path_finder.CalculatePath(starting_point, end_point)
|
||||||
elapsed_time, visited_nodes = path_finder.GetStats()
|
elapsed_time, visited_nodes = path_finder.GetStats()
|
||||||
print(f"{path_finder.name:22}: took {elapsed_time} ns, visited {visited_nodes} nodes")
|
print(f"{path_finder.name:22}: took {elapsed_time} ns, visited {visited_nodes} nodes")
|
||||||
v.DrawPath(path)
|
v.DrawPath(path)
|
||||||
#
|
#
|
||||||
p = (9,1)
|
p = (1,9)
|
||||||
print(f"{m.IsPointValid(p)=}")
|
# print(f"{m.IsPointValid(p)=}")
|
||||||
print(f"{m.GetCost(p)=}")
|
# print(f"{m.GetCost(p)=}")
|
||||||
|
print(f"{m.GetNeighbours(p)}")
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user