|
| 1 | +from dataclasses import dataclass |
| 2 | +from typing import Optional, TypeAlias |
| 3 | +import heapq |
| 4 | + |
| 5 | +from PathPlanning.TimeBasedPathPlanning.Node import NodePath, Position |
| 6 | + |
| 7 | +AgentId: TypeAlias = int |
| 8 | + |
| 9 | +@dataclass |
| 10 | +class Constraint: |
| 11 | + position: Position |
| 12 | + time: int |
| 13 | + |
| 14 | +@dataclass |
| 15 | +class PathConstraint: |
| 16 | + constraint: Constraint |
| 17 | + shorter_path_agent: AgentId |
| 18 | + longer_path_agent: AgentId |
| 19 | + |
| 20 | +@dataclass |
| 21 | +class ConstraintTreeNode: |
| 22 | + parent_idx = int |
| 23 | + constraint: tuple[AgentId, Constraint] |
| 24 | + |
| 25 | + paths: dict[AgentId, NodePath] |
| 26 | + cost: int |
| 27 | + |
| 28 | + def __lt__(self, other) -> bool: |
| 29 | + # TODO - this feels jank? |
| 30 | + return self.cost + self.constrained_path_cost() < other.cost + other.constrained_path_cost() |
| 31 | + |
| 32 | + def get_constraint_point(self) -> Optional[PathConstraint]: |
| 33 | + final_t = max(path.goal_reached_time() for path in self.paths) |
| 34 | + positions_at_time: dict[Position, AgentId] = {} |
| 35 | + for t in range(final_t + 1): |
| 36 | + # TODO: need to be REALLY careful that these agent ids are consitent |
| 37 | + for agent_id, path in self.paths.items(): |
| 38 | + position = path.get_position(t) |
| 39 | + if position is None: |
| 40 | + continue |
| 41 | + if position in positions_at_time: |
| 42 | + conflicting_agent_id = positions_at_time[position] |
| 43 | + this_agent_shorter = self.paths[agent_id].goal_reached_time() < self.paths[conflicting_agent_id].goal_reached_time() |
| 44 | + |
| 45 | + return PathConstraint( |
| 46 | + constraint=Constraint(position=position, time=t), |
| 47 | + shorter_path_agent= agent_id if this_agent_shorter else conflicting_agent_id, |
| 48 | + longer_path_agent= conflicting_agent_id if this_agent_shorter else agent_id |
| 49 | + ) |
| 50 | + return None |
| 51 | + |
| 52 | + def constrained_path_cost(self) -> int: |
| 53 | + constrained_path = self.paths[self.constraint[0]] |
| 54 | + return constrained_path.goal_reached_time() |
| 55 | + |
| 56 | +class ConstraintTree: |
| 57 | + # Child nodes have been created (Maps node_index to ConstraintTreeNode) |
| 58 | + expanded_nodes: dict[int, ConstraintTreeNode] |
| 59 | + # Need to solve and generate children |
| 60 | + nodes_to_expand: heapq[ConstraintTreeNode] |
| 61 | + |
| 62 | + solution: Optional[ConstraintTreeNode] = None |
| 63 | + |
| 64 | + def __init__(self, initial_solution: dict[AgentId, NodePath]): |
| 65 | + initial_cost = sum(path.goal_reached_time() for path in initial_solution.values()) |
| 66 | + heapq.heappush(self.nodes_to_expand, ConstraintTreeNode(constraints={}, paths=initial_solution, cost=initial_cost, parent_idx=-1)) |
| 67 | + |
| 68 | + def get_next_node_to_expand(self) -> Optional[ConstraintTreeNode]: |
| 69 | + if not self.nodes_to_expand: |
| 70 | + return None |
| 71 | + return heapq.heappop(self.nodes_to_expand) |
| 72 | + |
| 73 | + def add_node_to_tree(self, node: ConstraintTreeNode) -> bool: |
| 74 | + """ |
| 75 | + Add a node to the tree and generate children if needed. Returns true if the node is a solution, false otherwise. |
| 76 | + """ |
| 77 | + node_index = len(self.expanded_nodes) |
| 78 | + self.expanded_nodes[node_index] = node |
| 79 | + constraint_point = node.get_constraint_point() |
| 80 | + if constraint_point is None: |
| 81 | + # Don't need to add any constraints, this is a solution! |
| 82 | + self.solution = node |
| 83 | + return |
| 84 | + |
| 85 | + child_node1 = node |
| 86 | + child_node1.constraint = (constraint_point.shorter_path_agent, constraint_point.constraint) |
| 87 | + child_node1.parent_idx = node_index |
| 88 | + |
| 89 | + child_node2 = node |
| 90 | + child_node2.constraint = (constraint_point.longer_path_agent, constraint_point.constraint) |
| 91 | + child_node2.parent_idx = node_index |
| 92 | + |
| 93 | + heapq.heappush(self.nodes_to_expand, child_node1) |
| 94 | + heapq.heappush(self.nodes_to_expand, child_node2) |
| 95 | + |
| 96 | + def get_ancestor_constraints(self, parent_index: int): |
| 97 | + """ |
| 98 | + Get the constraints that were applied to the parent node to generate this node. |
| 99 | + """ |
| 100 | + constraints = [] |
| 101 | + while parent_index != -1: |
| 102 | + node = self.expanded_nodes[parent_index] |
| 103 | + if node.constraint is not None: |
| 104 | + constraints.append(node.constraint) |
| 105 | + parent_index = node.parent_idx |
| 106 | + return constraints |
0 commit comments