diff --git a/src/software/evaluation/loggers/BUILD b/src/software/evaluation/loggers/BUILD new file mode 100644 index 0000000000..384b7f214f --- /dev/null +++ b/src/software/evaluation/loggers/BUILD @@ -0,0 +1,9 @@ +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "stats_logger", + srcs = ["stats_logger.py"], + deps = [ + "//software/evaluation/trackers:tracker", + ], +) diff --git a/src/software/evaluation/loggers/stats_logger.py b/src/software/evaluation/loggers/stats_logger.py new file mode 100644 index 0000000000..c8d3264591 --- /dev/null +++ b/src/software/evaluation/loggers/stats_logger.py @@ -0,0 +1,170 @@ +import os + +from software.evaluation.trackers import ( + PossessionTracker, + ShotTracker, + PassTracker, + TrackerBuilder, + RefereeTracker, + GoalieTracker, +) +from dataclasses import dataclass +from software.thunderscope.proto_unix_io import ProtoUnixIO +from software.thunderscope.constants import RuntimeManagerConstants +from software.evaluation.logs.event_log import Team as TeamEnum, EventLog +import logging +from proto.import_all_protos import * +import queue + + +@dataclass +class FSStats: + """Stats for how well a FullSystem is performing""" + + num_yellow_cards: int = 0 + num_red_cards: int = 0 + num_scores: int = 0 + + num_shots_on_net: int = 0 + num_enemy_shots_blocked: int = 0 + + +class StatsLogger: + # From GoalieTacticConfig + INCOMING_SHOT_MIN_VELOCITY = 0.2 + + EVENT_BUFFER_SIZE = 100 + + def __init__( + self, + proto_unix_io: ProtoUnixIO, + friendly_colour_yellow: bool, + out_file_name: str | None = None, + buffer_size: int = 5, + record_enemy_stats: bool = False, + ): + """Initializes the FullSystem Stats Tracker + + :param friendly_colour_yellow: if the friendly colour is yellow + :param out_file_name: name of file to write stats to. + If None, uses the value from constants + :param buffer_size: the buffer size for protocol buffers + :param record_enemy_stats: if this should record both friendly and enemy stats or just friendly + """ + self.friendly_colour_yellow = friendly_colour_yellow + + self.events_file_path = os.path.join( + RuntimeManagerConstants.RUNTIME_EVENTS_DIRECTORY_PATH, + RuntimeManagerConstants.RUNTIME_EVENTS_FILE + if out_file_name is None + else out_file_name, + ) + # initialized in setup() + self.events_file_handle = None + + self.event_queue = queue.Queue(self.EVENT_BUFFER_SIZE) + + # flag to turn off logging stats if needed + self.logging_enabled = True + + self.tracker = ( + TrackerBuilder( + proto_unix_io=proto_unix_io, + from_team=( + TeamEnum.YELLOW if self.friendly_colour_yellow else TeamEnum.BLUE + ), + event_queue=self.event_queue, + buffer_size=buffer_size, + ) + .add_tracker(PassTracker) + .add_tracker(ShotTracker) + .add_tracker(PossessionTracker) + .add_tracker( + RefereeTracker, + friendly_color_yellow=self.friendly_colour_yellow, + toggle_logging=self._toggle_logging, + ) + .add_tracker(GoalieTracker, for_friendly=True) + ) + + self.record_enemy_stats = record_enemy_stats + if self.record_enemy_stats: + self.enemy_tracker = ( + TrackerBuilder( + proto_unix_io=proto_unix_io, + from_team=( + TeamEnum.YELLOW + if self.friendly_colour_yellow + else TeamEnum.BLUE + ), + for_team=( + TeamEnum.BLUE + if self.friendly_colour_yellow + else TeamEnum.YELLOW + ), + event_queue=self.event_queue, + buffer_size=buffer_size, + ) + .add_tracker( + RefereeTracker, + friendly_color_yellow=(not self.friendly_colour_yellow), + toggle_logging=self._toggle_logging, + ) + .add_tracker(GoalieTracker, for_friendly=False) + ) + + def refresh(self) -> None: + """Refreshes the events for the game so far""" + self.tracker.refresh() + + if not self.events_file_handle: + return + + while not self.event_queue.empty(): + try: + # Get item without blocking + event = self.event_queue.get_nowait() + + self._write_event_to_file(event) + except queue.Empty: + return + + def __enter__(self): + """Sets up the file resources for logging + Creates any missing directories and stores the file handle + """ + # create temp stats directory if it doesn't exist + os.makedirs(os.path.dirname(self.events_file_path), exist_ok=True) + + self.events_file_handle = open(self.events_file_path, "a") + + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Writes all logs back to file, and cleans up any created file resources after logging""" + if self.events_file_handle: + self.events_file_handle.flush() + self.events_file_handle.close() + + def _toggle_logging(self, should_log: bool) -> None: + """Turns logging off or on based on the given boolean + + ;param should_log: True if logging should continue, False if not + """ + self.logging_enabled = should_log + + def _write_event_to_file(self, event: EventLog) -> None: + """Write the given stats to the given file + + :param event: the event to write + """ + if not self.events_file_handle: + return + + try: + csv_row = event.to_csv_row() + self.events_file_handle.write(csv_row + "\n") + self.events_file_handle.flush() + + except (IOError, FileNotFoundError, PermissionError): + logging.warning("Failed to write event to file") diff --git a/src/software/evaluation/logs/BUILD b/src/software/evaluation/logs/BUILD new file mode 100644 index 0000000000..c5e526b184 --- /dev/null +++ b/src/software/evaluation/logs/BUILD @@ -0,0 +1,39 @@ +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "log_interface", + srcs = ["log_interface.py"], + deps = [ + "//proto:import_all_protos", + "//software/thunderscope:time_provider", + ], +) + +py_library( + name = "world_state_log", + srcs = ["world_state_log.py"], + data = [ + "//software:py_constants.so", + ], + deps = [ + ":log_interface", + ], +) + +py_library( + name = "event_log", + srcs = ["event_log.py"], + deps = [ + ":log_interface", + ":world_state_log", + ], +) + +py_library( + name = "logs", + deps = [ + ":event_log", + ":log_interface", + ":world_state_log", + ], +) diff --git a/src/software/evaluation/logs/event_log.py b/src/software/evaluation/logs/event_log.py new file mode 100644 index 0000000000..b79b2e3b0f --- /dev/null +++ b/src/software/evaluation/logs/event_log.py @@ -0,0 +1,107 @@ +from __future__ import annotations +from dataclasses import dataclass +from enum import StrEnum, auto +from proto.import_all_protos import * +from typing import Any, override +from software.evaluation.logs.log_interface import TimestampedEvalLog +from software.evaluation.logs.world_state_log import WorldStateLog + + +class EventType(StrEnum): + """Enum for the different types of events we want to track""" + + PASS = auto() + SHOT_ON_GOAL = auto() + ENEMY_SHOT_ON_GOAL = auto() + SHOT_BLOCKED = auto() + FRIENDLY_POSSESSION_START = auto() + FRIENDLY_POSSESSION_END = auto() + ENEMY_POSSESSION_START = auto() + ENEMY_POSSESSION_END = auto() + GAME_START = auto() + GAME_END = auto() + GOAL_SCORED = auto() + YELLOW_CARD = auto() + RED_CARD = auto() + + +class Team(StrEnum): + """The teams present in the game""" + + BLUE = auto() + YELLOW = auto() + + +@dataclass(kw_only=True) +class EventLog(TimestampedEvalLog): + """Represents a single event being tracked, where and for whom the event is, + and the game state at the time of the event + """ + + event_type: EventType + from_team: Team + for_team: Team + world_state_log: WorldStateLog + + num_cols = TimestampedEvalLog.get_num_cols() + 3 + WorldStateLog.get_num_cols() + + @staticmethod + def from_world( + world_msg: World, event_type: EventType, from_team: Team, for_team: Team + ) -> EventLog: + """Creates an EventLog from a world protobuf message + + :param world_msg: the world object containing the state of the game + :param event_type: the type of event being recorded + :param from_team: the team that the event is coming from + :param for_team: the team that the event is for + :return: a fully populated EventLog including world state + """ + world_state_log = WorldStateLog.from_world(world_msg=world_msg) + + return EventLog( + event_type=event_type, + from_team=from_team, + for_team=for_team, + world_state_log=world_state_log, + ) + + @override + @classmethod + def get_num_cols(cls) -> int: + return EventLog.num_cols + + @override + def to_array(self) -> list[Any]: + return ( + super().to_array() + + [ + self.event_type.value, + self.from_team.value, + self.for_team.value, + ] + + self.world_state_log.to_array() + ) + + @staticmethod + @override + def from_csv_row(row_iter: Iterator[str]) -> EventLog | None: + """Parses a full CSV row into an EventLog.""" + timestamp = float(next(row_iter)) + + event_type = EventType(next(row_iter)) + from_team = Team(next(row_iter)) + for_team = Team(next(row_iter)) + + world_state = WorldStateLog.from_csv_row(row_iter) + + if not world_state: + return None + + return EventLog( + timestamp=timestamp, + event_type=event_type, + from_team=from_team, + for_team=for_team, + world_state_log=world_state, + ) diff --git a/src/software/evaluation/logs/log_interface.py b/src/software/evaluation/logs/log_interface.py new file mode 100644 index 0000000000..728af35adb --- /dev/null +++ b/src/software/evaluation/logs/log_interface.py @@ -0,0 +1,85 @@ +from __future__ import annotations +from abc import abstractmethod, ABC +from dataclasses import dataclass, field +from proto.import_all_protos import * +from typing import Iterator, Any, override +from google.protobuf.descriptor import Descriptor, FieldDescriptor +from software.thunderscope.time_provider import time_provider_instance + + +def count_primitive_fields(descriptor: Descriptor): + """Recursively counts the number of primitive fields in a Protobuf message + using its descriptor. + + :param message: the message descriptor to count all leaf-level primitive fields for + :return: the count of primitive fields + """ + count = 0 + + for field in descriptor.fields: + # Check if the field is a nested message + if field.type == FieldDescriptor.TYPE_MESSAGE: + # Get the nested message class to recurse into its descriptor + nested_message = field.message_type + # Recurse using the nested message's descriptor + count += count_primitive_fields(nested_message) + else: + # It's a primitive type (double, float, int, bool, string, etc.) + count += 1 + return count + + +class IEvalLog(ABC): + @classmethod + @abstractmethod + def get_num_cols(cls) -> int: + """Gets the number of columns present in this log""" + raise NotImplementedError("Please use the appropriate subclass of log!") + + @abstractmethod + def to_array(self) -> list[Any]: + """Converts this log to an array of elements""" + raise NotImplementedError("Please use the appropriate subclass of log!") + + def to_csv_row(self): + """Converts this log into a Comma Separated Values string + + :return: a string of values separated by commas + """ + row_array = self.to_array() + assert len(row_array) == type(self).get_num_cols() + + return ",".join([str(elem) for elem in row_array]) + + @staticmethod + @abstractmethod + def from_csv_row(row_iter: Iterator[str], **kwargs) -> IEvalLog | None: + """Converts a CSV row into an instance of this log + + :param row_iter: an iterator representing a csv row, which returns elements one by one + :param **kwargs: any extra arguments needed for this log not present in the csv row + """ + raise NotImplementedError("Please use the appropriate subclass of log!") + + +@dataclass +class TimestampedEvalLog(IEvalLog): + timestamp: float = field(default_factory=time_provider_instance.elapsed_time_ns) + + def get_timestamp(self) -> float: + """Get this log's timestamp""" + return self.timestamp + + @classmethod + @override + def get_num_cols(cls) -> int: + return 1 + + @override + def to_array(self) -> list[Any]: + return [self.timestamp] + + @staticmethod + @override + def from_csv_row(row_iter: Iterator[str], **kwargs) -> TimestampedEvalLog | None: + return TimestampedEvalLog(timestamp=float(next(row_iter))) diff --git a/src/software/evaluation/logs/world_state_log.py b/src/software/evaluation/logs/world_state_log.py new file mode 100644 index 0000000000..5efc11f915 --- /dev/null +++ b/src/software/evaluation/logs/world_state_log.py @@ -0,0 +1,209 @@ +from __future__ import annotations +from dataclasses import dataclass +from proto.import_all_protos import * +from software.evaluation.logs.log_interface import IEvalLog, count_primitive_fields +from software.py_constants import DIV_B_NUM_ROBOTS +from typing import Any, Iterator, override + + +@dataclass +class RobotLog(IEvalLog): + """Represents a single robot on the field, with ID and current state.""" + + id: int + state: RobotState + + num_cols: int = count_primitive_fields(RobotState.DESCRIPTOR) + + @classmethod + @override + def get_num_cols(cls) -> int: + return RobotLog.num_cols + + def get_position(self) -> list[float]: + """Returns the current ball position as a [float, float] array + represnting x, y coordinates + """ + return [ + self.state.global_position.x_meters, + self.state.global_position.y_meters, + ] + + @override + def to_array(self) -> list[Any]: + return self.get_position() + [ + self.state.global_orientation.radians, + self.state.global_velocity.x_component_meters, + self.state.global_velocity.y_component_meters, + self.state.global_angular_velocity.radians_per_second, + ] + + @staticmethod + @override + def from_csv_row(row_iter: Iterator[str], id: int = 0) -> RobotLog | None: + """Parses a full CSV row into an RobotLog + + :param id: the id of the robot + :return: the RobotLog object represented by the csv row + """ + data = [next(row_iter) for _ in range(RobotLog.num_cols)] + + # for missing robots + if all(val == "None" for val in data): + return None + + state = RobotState( + global_position=Point(x_meters=float(data[0]), y_meters=float(data[1])), + global_orientation=Angle(radians=float(data[2])), + global_velocity=Vector( + x_component_meters=float(data[3]), y_component_meters=float(data[4]) + ), + global_angular_velocity=AngularVelocity(radians_per_second=float(data[5])), + ) + + return RobotLog(id=id, state=state) + + +@dataclass +class BallLog(IEvalLog): + """Represents a single ball on the field.""" + + state: BallState + + num_cols: int = count_primitive_fields(BallState.DESCRIPTOR) - 1 + + def get_position(self) -> list[float]: + """Returns the current ball position as a [float, float] array + represnting x, y coordinates + """ + return [ + self.state.global_position.x_meters, + self.state.global_position.y_meters, + ] + + @classmethod + @override + def get_num_cols(cls) -> int: + return BallLog.num_cols + + @override + def to_array(self) -> list[Any]: + return self.get_position() + [ + self.state.global_velocity.x_component_meters, + self.state.global_velocity.y_component_meters, + ] + + @staticmethod + @override + def from_csv_row(row_iter: Iterator[str]) -> BallLog | None: + """Parses a full CSV row into an BallLog""" + data = [next(row_iter) for _ in range(BallLog.num_cols)] + + state = BallState( + global_position=Point(x_meters=float(data[0]), y_meters=float(data[1])), + global_velocity=Vector( + x_component_meters=float(data[2]), y_component_meters=float(data[3]) + ), + distance_from_ground=0.0, + ) + + return BallLog(state=state) + + +@dataclass +class WorldStateLog(IEvalLog): + """Represents the current state of the world""" + + ball_state: BallLog + friendly_robots: list[RobotLog] + enemy_robots: list[RobotLog] + + num_cols: int = ( + DIV_B_NUM_ROBOTS * RobotLog.get_num_cols() * 2 + BallLog.get_num_cols() + ) + + @staticmethod + def from_world(world_msg: World) -> WorldStateLog: + """Creates a WorldStateLog from a world protobuf message + + :param world_msg: the world object containing the state of the game + :return: a fully populated WorldStateLog including ball and robot states + """ + ball_state = BallLog(state=world_msg.ball.current_state) + + friendly_robots = [ + RobotLog(id=robot.id, state=robot.current_state) + for robot in world_msg.friendly_team.team_robots + ] + enemy_robots = [ + RobotLog(id=robot.id, state=robot.current_state) + for robot in world_msg.enemy_team.team_robots + ] + + return WorldStateLog( + ball_state=ball_state, + friendly_robots=friendly_robots, + enemy_robots=enemy_robots, + ) + + @classmethod + @override + def get_num_cols(cls) -> int: + return WorldStateLog.num_cols + + def robots_to_array(self, robots: list[RobotLog]) -> list[Any]: + """Serializes robots into flattened columns within a list + + :param robot_states: the list of RobotState objects to flatten + :return: a list of values representing the states for all robots + """ + num_cols_per_robot = RobotLog.get_num_cols() + + # robot state columns will be added based on robot id + robot_state_map = {robot.id: robot for robot in robots} + + robots_array = [] + + # Add friendly robots: [r1_data, r2_data...] based on id + for idx in range(DIV_B_NUM_ROBOTS): + new_state = [] + + if idx not in robot_state_map: + new_state = [None for _ in range(num_cols_per_robot)] + else: + robot_state = robot_state_map[idx] + new_state = robot_state.to_array() + + robots_array.extend(new_state) + + return robots_array + + @override + def to_array(self) -> list[Any]: + return ( + self.ball_state.to_array() + + self.robots_to_array(self.friendly_robots) + + self.robots_to_array(self.enemy_robots) + ) + + @staticmethod + def from_csv_row(row_iter: Iterator[str]) -> WorldStateLog | None: + ball_state = BallLog.from_csv_row(row_iter) + + friendly_robots = [] + for id in range(DIV_B_NUM_ROBOTS): + robot = RobotLog.from_csv_row(row_iter, id=id) + if robot: + friendly_robots.append(robot) + + enemy_robots = [] + for id in range(DIV_B_NUM_ROBOTS): + robot = RobotLog.from_csv_row(row_iter, id=id) + if robot: + enemy_robots.append(robot) + + return WorldStateLog( + ball_state=ball_state, + friendly_robots=friendly_robots, + enemy_robots=enemy_robots, + ) diff --git a/src/software/thunderscope/log/trackers/BUILD b/src/software/evaluation/trackers/BUILD similarity index 93% rename from src/software/thunderscope/log/trackers/BUILD rename to src/software/evaluation/trackers/BUILD index 79a4365113..44030df25f 100644 --- a/src/software/thunderscope/log/trackers/BUILD +++ b/src/software/evaluation/trackers/BUILD @@ -10,7 +10,6 @@ py_library( "kick_tracker.py", "possession_tracker.py", "referee_tracker.py", - "tracked_event.py", "tracker.py", "tracker_builder.py", ], @@ -20,6 +19,7 @@ py_library( ], visibility = ["//visibility:public"], deps = [ + "//software/evaluation/logs", "//software/thunderscope:time_provider", ], ) diff --git a/src/software/evaluation/trackers/__init__.py b/src/software/evaluation/trackers/__init__.py new file mode 100644 index 0000000000..3c15538776 --- /dev/null +++ b/src/software/evaluation/trackers/__init__.py @@ -0,0 +1,14 @@ +from software.evaluation.trackers.kick_tracker import ShotTracker, PassTracker +from software.evaluation.trackers.possession_tracker import PossessionTracker +from software.evaluation.trackers.tracker_builder import TrackerBuilder +from software.evaluation.trackers.referee_tracker import RefereeTracker +from software.evaluation.trackers.goalie_tracker import GoalieTracker + +__all__ = [ + "PossessionTracker", + "ShotTracker", + "PassTracker", + "TrackerBuilder", + "RefereeTracker", + "GoalieTracker", +] diff --git a/src/software/thunderscope/log/trackers/goalie_tracker.py b/src/software/evaluation/trackers/goalie_tracker.py similarity index 68% rename from src/software/thunderscope/log/trackers/goalie_tracker.py rename to src/software/evaluation/trackers/goalie_tracker.py index 36ee9af432..0eaf1c990c 100644 --- a/src/software/thunderscope/log/trackers/goalie_tracker.py +++ b/src/software/evaluation/trackers/goalie_tracker.py @@ -1,10 +1,11 @@ -from software.thunderscope.log.trackers.tracker import Tracker -from typing import Callable, override +from software.evaluation.trackers.tracker import Tracker +from typing import override from proto.import_all_protos import * -from software.thunderscope.thread_safe_buffer import ThreadSafeBuffer from software.thunderscope.proto_unix_io import ProtoUnixIO import software.python_bindings as tbots_cpp from software.py_constants import ROBOT_MAX_RADIUS_METERS +from software.evaluation.logs.event_log import EventType, Team +import queue class GoalieTracker(Tracker): @@ -17,56 +18,64 @@ class GoalieTracker(Tracker): def __init__( self, + proto_unix_io: ProtoUnixIO, + from_team: Team, + for_team: Team, + event_queue: queue.Queue, for_friendly: bool, - callback: Callable[[bool, bool], None], - buffer_size: int = 5, + **kwargs, ): """Initializes the Goalie tracker :param for_friendly: if we should track shots on goal for the friendly or enemy team - :param callback: function to call when there is a new shot on goal - called with 2 booleans: - - If there is a shot on goal right now - - If there was a shot on goal before - lets us track new shots on goal + when shots are blocked - :param buffer_size: buffer size for the tracker's io + :param proto_unix_io: the proto unix io to get the game state from + :param from_team: the team that this tracker is tracking from (events are from this team) + :param for_team: the team that this tracker is tracking for (events are for this team) + default is same as the from_team, but can be different + :param event_queue: the queue to write events to """ - super().__init__(callback=callback, buffer_size=buffer_size) - - self.world_buffer = ThreadSafeBuffer(buffer_size, World) + super().__init__( + proto_unix_io=proto_unix_io, + from_team=from_team, + for_team=for_team, + event_queue=event_queue, + **kwargs, + ) self.for_friendly = for_friendly self.is_shot_incoming = False @override - def set_proto_unix_io(self, proto_unix_io: ProtoUnixIO) -> None: - super().set_proto_unix_io( - proto_unix_io, - [ - (World, self.world_buffer), - ], - ) - - @override - def refresh(self): - """Refresh and update the callback with the latest shot on goal information""" - world_msg = self.world_buffer.get(block=False, return_cached=True) - - if not world_msg: + def refresh_tracker(self) -> None: + """Refresh and log any new shots on goal""" + if self.cached_world is None: return - world = tbots_cpp.World(world_msg) - latest_is_shot_incoming = self._is_goal_shot_incoming( - world.ball(), world.field(), for_friendly=self.for_friendly + self.cached_world.ball(), + self.cached_world.field(), + for_friendly=self.for_friendly, ) - if self.callback: - self.callback(latest_is_shot_incoming, self.is_shot_incoming) + self._log_incoming_shot(latest_is_shot_incoming) self.is_shot_incoming = latest_is_shot_incoming + def _log_incoming_shot(self, new_shot_incoming): + event_type = None + + if not new_shot_incoming and self.is_shot_incoming: + event_type = EventType.SHOT_BLOCKED + + if new_shot_incoming and not self.is_shot_incoming: + event_type = EventType.ENEMY_SHOT_ON_GOAL + + if not event_type: + return + + self.write_event(event_type=event_type) + def _get_goal_shot_region( self, field: tbots_cpp.Field, for_friendly: bool ) -> tbots_cpp.Rectangle: diff --git a/src/software/thunderscope/log/trackers/kick_tracker.py b/src/software/evaluation/trackers/kick_tracker.py similarity index 66% rename from src/software/thunderscope/log/trackers/kick_tracker.py rename to src/software/evaluation/trackers/kick_tracker.py index 0d7c6d0153..3c2d683aec 100644 --- a/src/software/thunderscope/log/trackers/kick_tracker.py +++ b/src/software/evaluation/trackers/kick_tracker.py @@ -3,10 +3,12 @@ import software.python_bindings as tbots_cpp from proto.visualization_pb2 import AttackerVisualization from proto.import_all_protos import * -from typing import Callable, Any, override +from typing import override from software.thunderscope.thread_safe_buffer import ThreadSafeBuffer -from software.thunderscope.log.trackers.tracker import Tracker +from software.evaluation.trackers.tracker import Tracker from software.thunderscope.proto_unix_io import ProtoUnixIO +from software.evaluation.logs.event_log import EventType, Team +import queue class KickTracker(Tracker): @@ -37,31 +39,40 @@ class KickTracker(Tracker): def __init__( self, - callback: Optional[Callable[[Any], None]] = None, - buffer_size: int = 5, + proto_unix_io: ProtoUnixIO, + from_team: Team, + for_team: Team, + event_queue: queue.Queue, + **kwargs, ): - """Initialize the kick tracker - :param callback: an optional callback to call when there's a kick + """Initializes the KickTracker + + :param proto_unix_io: the proto unix io to get the game state from + :param from_team: the team that this tracker is tracking from (events are from this team) + :param for_team: the team that this tracker is tracking for (events are for this team) + default is same as the from_team, but can be different + :param event_queue: the queue to write events to """ - super().__init__(callback=callback, buffer_size=buffer_size) + super().__init__( + proto_unix_io=proto_unix_io, + from_team=from_team, + for_team=for_team, + event_queue=event_queue, + **kwargs, + ) + self.latest_kick_angle: tbots_cpp.Angle = tbots_cpp.Angle() self.kick_taken = False self.attacker_vis_buffer = ThreadSafeBuffer( self.buffer_size, AttackerVisualization ) - self.world_buffer = ThreadSafeBuffer(self.buffer_size, World) - - @override - def set_proto_unix_io(self, proto_unix_io: ProtoUnixIO) -> None: - super().set_proto_unix_io( - proto_unix_io, - [ - (AttackerVisualization, self.attacker_vis_buffer), - (World, self.world_buffer), - ], + self.proto_unix_io.register_observer( + AttackerVisualization, self.attacker_vis_buffer ) + self.curr_pass = None + def _get_new_kick_angle( self, origin: Point, target: Point, latest_angle: tbots_cpp.Angle ) -> Optional[tbots_cpp.Angle]: @@ -94,22 +105,20 @@ def _get_new_kick_angle( return None - def refresh(self) -> None: + @override + def refresh_tracker(self) -> None: """Refreshes the tracker by getting the current state of the world and the latest attacker visualization """ - attacker_vis_msg = self.attacker_vis_buffer.get(block=False) - - if not attacker_vis_msg: + if self.cached_world is None: return - world_msg = self.world_buffer.get(block=False, return_cached=True) + attacker_vis_msg = self.attacker_vis_buffer.get(block=False) - if world_msg is None: + if not attacker_vis_msg: return - world = tbots_cpp.World(world_msg) - self._refresh_kicks(attacker_vis_msg, world) + self._refresh_kicks(attacker_vis_msg, self.cached_world) def _refresh_kicks( self, attacker_vis_msg: AttackerVisualization, world: tbots_cpp.World @@ -120,18 +129,38 @@ def _refresh_kicks( class PassTracker(KickTracker): """Tracker for tracking the attacker's passes""" - def __init__(self, callback: Callable[[Pass], None] = None, buffer_size: int = 5): - """Initialize the pass tracker - :param callback: an optional callback to call when there's a pass + def __init__( + self, + proto_unix_io: ProtoUnixIO, + from_team: Team, + for_team: Team, + event_queue: queue.Queue, + **kwargs, + ): + """Initializes the PassTracker + + :param proto_unix_io: the proto unix io to get the game state from + :param from_team: the team that this tracker is tracking from (events are from this team) + :param for_team: the team that this tracker is tracking for (events are for this team) + default is same as the from_team, but can be different + :param event_queue: the queue to write events to """ - super().__init__(callback=callback, buffer_size=buffer_size) + super().__init__( + proto_unix_io=proto_unix_io, + from_team=from_team, + for_team=for_team, + event_queue=event_queue, + **kwargs, + ) @override def _refresh_kicks( self, attacker_vis_msg: AttackerVisualization, world: tbots_cpp.World ) -> None: """Refreshes the pass tracker with the new attacker visualization - and the latest state of the world + and the latest state of the world. + + Logs any new pass events. :param attacker_vis_msg: the latest attacker visualization message :param world: the current world state @@ -151,6 +180,7 @@ def _refresh_kicks( if new_pass_angle is not None: self.latest_kick_angle = new_pass_angle self.kick_taken = False + self.curr_pass = attacker_vis_msg.pass_ ball = world.ball() @@ -160,27 +190,46 @@ def _refresh_kicks( self.MIN_SHOT_SPEED, self.MAX_KICK_ANGLE_DIFFERENCE, ): + self.write_event(event_type=EventType.PASS) self.kick_taken = True - - if self.callback: - self.callback(attacker_vis_msg.pass_) + self.curr_pass = None class ShotTracker(KickTracker): """Tracker for tracking the attacker's shots on goal""" - def __init__(self, callback: Callable[[Shot], None] = None, buffer_size: int = 5): - """Initialize the shot tracker - :param callback: an optional callback to call when there's a shot + def __init__( + self, + proto_unix_io: ProtoUnixIO, + from_team: Team, + for_team: Team, + event_queue: queue.Queue, + **kwargs, + ): + """Initializes the ShotTracker + + :param proto_unix_io: the proto unix io to get the game state from + :param from_team: the team that this tracker is tracking from (events are from this team) + :param for_team: the team that this tracker is tracking for (events are for this team) + default is same as the from_team, but can be different + :param event_queue: the queue to write events to """ - super().__init__(callback=callback, buffer_size=buffer_size) + super().__init__( + proto_unix_io=proto_unix_io, + from_team=from_team, + for_team=for_team, + event_queue=event_queue, + **kwargs, + ) @override def _refresh_kicks( self, attacker_vis_msg: AttackerVisualization, world: tbots_cpp.World ) -> None: """Refreshes the shot tracker with the new attacker visualization - and the latest state of the world + and the latest state of the world. + + Logs any new shot events. :param attacker_vis_msg: the latest attacker visualization message :param world: the current world state @@ -213,5 +262,4 @@ def _refresh_kicks( ): self.kick_taken = True - if self.callback: - self.callback(attacker_vis_msg.shot) + self.write_event(event_type=EventType.SHOT_ON_GOAL) diff --git a/src/software/evaluation/trackers/possession_tracker.py b/src/software/evaluation/trackers/possession_tracker.py new file mode 100644 index 0000000000..c7dcd3ede2 --- /dev/null +++ b/src/software/evaluation/trackers/possession_tracker.py @@ -0,0 +1,114 @@ +from software.evaluation.trackers.tracker import Tracker +from typing import override +from software.thunderscope.proto_unix_io import ProtoUnixIO +from proto.import_all_protos import * +import software.python_bindings as tbots_cpp +from software.evaluation.logs.event_log import EventType, Team +import queue +from software.py_constants import BALL_TO_FRONT_OF_ROBOT_DISTANCE_WHEN_DRIBBLING + + +class PossessionTracker(Tracker): + """Tracker to track and log when ball possession changes""" + + def __init__( + self, + proto_unix_io: ProtoUnixIO, + from_team: Team, + for_team: Team, + event_queue: queue.Queue, + **kwargs, + ): + """Initializes the PossessionTracker + + :param proto_unix_io: the proto unix io to get the game state from + :param from_team: the team that this tracker is tracking from (events are from this team) + :param for_team: the team that this tracker is tracking for (events are for this team) + default is same as the from_team, but can be different + :param event_queue: the queue to write events to + """ + super().__init__( + proto_unix_io=proto_unix_io, + from_team=from_team, + for_team=for_team, + event_queue=event_queue, + **kwargs, + ) + + # start with no team having possession + self.curr_possession = None + + @override + def refresh_tracker(self) -> None: + """Refresh and logs any changes in ball possession""" + if self.cached_world is None: + return + + self._log_posession_for_friendly( + self.cached_world.friendlyTeam(), + self.cached_world.enemyTeam(), + self.cached_world.ball().position(), + ) + + def _log_posession_for_friendly( + self, + friendly_team: tbots_cpp.Team, + enemy_team: tbots_cpp.Team, + ball_position: tbots_cpp.Point, + ) -> None: + """Detects possession changes and logs the corresponding start/end events + + Checks the current world state against the last known possession. + True if friendly possession, False is enemy possession, None if neither + + If a transition occurs (e.g., Friendly -> None, None -> Enemy), it triggers + a write_event with the appropriate EventType and updates the internal + possession state. + + :param friendly_team: the friendly team object + :param enemy_team: the enemy team object + :param ball_position: the current position of the ball + :return: None + """ + new_possession = None + + if self._check_posession_for_team(friendly_team, ball_position): + new_possession = True + elif self._check_posession_for_team(enemy_team, ball_position): + new_possession = False + + # if possession didn't change, no need to log + if new_possession == self.curr_possession: + return + + # mark the end of the last possession since it has changed + if self.curr_possession: + self.write_event(event_type=EventType.FRIENDLY_POSSESSION_END) + elif self.curr_possession == False: + self.write_event(event_type=EventType.ENEMY_POSSESSION_END) + + # log the start of the new, changed possession + if new_possession: + self.write_event(event_type=EventType.FRIENDLY_POSSESSION_START) + elif new_possession == False: + self.write_event(event_type=EventType.ENEMY_POSSESSION_START) + + self.curr_possession = new_possession + + def _check_posession_for_team( + self, team: tbots_cpp.Team, ball_position: tbots_cpp.Point + ) -> bool: + """Check if the given team has possession of the ball + + :param team: the team to check + :param ball_position: the current ball position + :return: True if the team has possession, False otherwise + """ + for robot in team.getAllRobots(): + # higher tolerance to make possession a bit stickier + if robot.isNearDribbler( + ball_position, BALL_TO_FRONT_OF_ROBOT_DISTANCE_WHEN_DRIBBLING * 2 + ): + return True + + return False diff --git a/src/software/evaluation/trackers/referee_tracker.py b/src/software/evaluation/trackers/referee_tracker.py new file mode 100644 index 0000000000..b9e3e8e96c --- /dev/null +++ b/src/software/evaluation/trackers/referee_tracker.py @@ -0,0 +1,138 @@ +from typing import override, Callable +from software.evaluation.trackers.tracker import Tracker +from proto.import_all_protos import * +from software.thunderscope.thread_safe_buffer import ThreadSafeBuffer +from software.thunderscope.proto_unix_io import ProtoUnixIO +from software.evaluation.logs.event_log import EventType, Team +import queue + + +class RefereeTracker(Tracker): + """Tracks Referee events, like goals and yellow / red cards for the friendly team only""" + + # we want to ignore all breaks, times before the game actually starts + # and all penalty related stages + STAGES_TO_IGNORE = [ + Referee.Stage.PENALTY_SHOOTOUT_BREAK, + Referee.Stage.PENALTY_SHOOTOUT, + Referee.Stage.NORMAL_FIRST_HALF_PRE, + Referee.Stage.NORMAL_SECOND_HALF_PRE, + Referee.Stage.EXTRA_TIME_BREAK, + Referee.Stage.EXTRA_FIRST_HALF_PRE, + Referee.Stage.EXTRA_SECOND_HALF_PRE, + ] + + def __init__( + self, + proto_unix_io: ProtoUnixIO, + from_team: Team, + for_team: Team, + event_queue: queue.Queue, + friendly_color_yellow: bool, + toggle_logging: Callable[[bool], None] | None = None, + **kwargs, + ): + """Initializes the RefereeTracker + + :param proto_unix_io: the proto unix io to get the game state from + :param from_team: the team that this tracker is tracking from (events are from this team) + :param for_team: the team that this tracker is tracking for (events are for this team) + default is same as the from_team, but can be different + :param event_queue: the queue to write events to + :param friendly_color_yellow: if the friendly color is yellow or blue + """ + super().__init__( + proto_unix_io=proto_unix_io, + from_team=from_team, + for_team=for_team, + event_queue=event_queue, + **kwargs, + ) + + self.referee_buffer = ThreadSafeBuffer(self.buffer_size, Referee) + self.proto_unix_io.register_observer(Referee, self.referee_buffer) + + self.friendly_color_yellow = friendly_color_yellow + + # we can use this callback to turn on / off logging + # during stages we don't care about + self.toggle_logging = toggle_logging + + self.num_yellow_cards = 0 + self.num_red_cards = 0 + self.num_goals = 0 + + self.curr_stage = None + + @override + def refresh_tracker(self) -> None: + """Refresh and log the latest referee information""" + referee_msg = self.referee_buffer.get(block=False, return_cached=True) + + if not referee_msg: + return + + game_stage = referee_msg.stage + + if game_stage in self.STAGES_TO_IGNORE: + if self.toggle_logging: + self.toggle_logging(False) + return + + if self.toggle_logging: + self.toggle_logging(True) + + # if game has just started, log a game start event once + if game_stage == Referee.Stage.NORMAL_FIRST_HALF: + self.curr_stage = self._log_event_if_change( + new_value=game_stage, + old_value=self.curr_stage, + event_type=EventType.GAME_START, + ) + return + + # if the game has just ended, log a game end event once + if game_stage == Referee.Stage.POST_GAME: + self.curr_stage = self._log_event_if_change( + new_value=game_stage, + old_value=self.curr_stage, + event_type=EventType.GAME_END, + ) + return + + self.curr_stage = game_stage + + if referee_msg.HasField("yellow" if self.friendly_color_yellow else "blue"): + team_info = ( + referee_msg.yellow if self.friendly_color_yellow else referee_msg.blue + ) + + if team_info.HasField("score"): + self.num_goals = self._log_event_if_change( + team_info.score, self.num_goals, EventType.GOAL_SCORED + ) + + if team_info.HasField("yellow_cards"): + self.num_yellow_cards = self._log_event_if_change( + team_info.yellow_cards, self.num_yellow_cards, EventType.YELLOW_CARD + ) + + if team_info.HasField("red_cards"): + self.num_red_cards = self._log_event_if_change( + team_info.red_cards, self.num_red_cards, EventType.RED_CARD + ) + + def _log_event_if_change( + self, new_value: int, old_value: int, event_type: EventType + ) -> int: + """Logs an event of the given type if the given value has changed between old and new + + :param new_value: the new value + :param old_value: the old value to compare with + :param event_type: the type of event to log if a change is detected + :return: the new value unchanged + """ + if new_value != old_value: + self.write_event(event_type=event_type) + + return new_value diff --git a/src/software/evaluation/trackers/tracker.py b/src/software/evaluation/trackers/tracker.py new file mode 100644 index 0000000000..4d69c41d4e --- /dev/null +++ b/src/software/evaluation/trackers/tracker.py @@ -0,0 +1,74 @@ +from software.thunderscope.proto_unix_io import ProtoUnixIO +from typing import Callable +from software.thunderscope.thread_safe_buffer import ThreadSafeBuffer +from proto.import_all_protos import * +import software.python_bindings as tbots_cpp +import queue +from software.evaluation.logs.event_log import EventType, Team, EventLog + + +class Tracker: + """Generic tracker base class. Just tracks the world state.""" + + def __init__( + self, + proto_unix_io: ProtoUnixIO, + from_team: Team, + event_queue: queue.Queue, + for_team: Team, + callback: Callable[[EventType], None] = None, + buffer_size: int = 5, + ): + """Initializes the tracker with the given callback and buffer size + + :param proto_unix_io: the proto unix io to get the game state from + :param from_team: the team that this tracker is tracking from (events are from this team) + :param for_team: the team that this tracker is tracking for (events are for this team) + default is same as the from_team, but can be different + :param event_queue: the queue to write events to + :param callback: optional callback to call when there's an event + :param buffer_size: buffer size for the tracker's io + """ + self.event_queue = event_queue + self.buffer_size = buffer_size + self.from_team = from_team + self.for_team = for_team + self.callback = callback + + self.proto_unix_io = proto_unix_io + self.world_buffer = ThreadSafeBuffer(self.buffer_size, World) + self.proto_unix_io.register_observer(World, self.world_buffer) + + self.cached_world = None + + def refresh(self) -> None: + """Refreshes the tracker to get the latest world message""" + world_msg = self.world_buffer.get(block=False, return_cached=True) + + if world_msg is None: + return + + self.cached_world_msg = world_msg + self.cached_world = tbots_cpp.World(world_msg) + + self.refresh_tracker() + + def refresh_tracker(self) -> None: + pass + + def write_event(self, event_type: EventType) -> None: + """Writes a single event to the event queue of the given type + + :param event_type: the type of event to log + """ + if not self.cached_world: + return + + event = EventLog.from_world( + world_msg=self.cached_world_msg, + event_type=event_type, + from_team=self.from_team, + for_team=self.for_team, + ) + + self.event_queue.put(event) diff --git a/src/software/evaluation/trackers/tracker_builder.py b/src/software/evaluation/trackers/tracker_builder.py new file mode 100644 index 0000000000..eaba7b16c2 --- /dev/null +++ b/src/software/evaluation/trackers/tracker_builder.py @@ -0,0 +1,60 @@ +from software.thunderscope.proto_unix_io import ProtoUnixIO +from software.evaluation.trackers.tracker import Tracker +from typing import Type, Self +from software.evaluation.logs.event_log import Team +import queue + + +class TrackerBuilder: + """Builder class to combine different trackers and update them together""" + + def __init__( + self, + proto_unix_io: ProtoUnixIO, + from_team: Team, + event_queue: queue.Queue, + for_team: Team | None = None, + buffer_size: int = 5, + ) -> None: + """Initializes the builder + + :param proto_unix_io: the unix io that the trackers should listen on + :param from_team: the team that this tracker is tracking from (events are from this team) + :param for_team: the team that this tracker is tracking for (events are for this team) + default is same as the from_team, but can be different + :param event_queue: the queue to write events to + """ + self.proto_unix_io = proto_unix_io + self.from_team = from_team + self.for_team = from_team if for_team is None else for_team + self.buffer_size = buffer_size + + self.event_queue = event_queue + + self.trackers = [] + + def add_tracker( + self, + tracker_cls: Type[Tracker], + **kwargs, + ) -> Self: + """Adds a single tracker to the list + + :param tracker_cls: The class of the tracker to instantiate + :param **kwargs: tracker-specific arguments + """ + tracker = tracker_cls( + proto_unix_io=self.proto_unix_io, + event_queue=self.event_queue, + from_team=self.from_team, + for_team=self.for_team, + buffer_size=self.buffer_size, + **kwargs, + ) + self.trackers.append(tracker) + return self + + def refresh(self) -> None: + """Refreshes all the trackers""" + for tracker in self.trackers: + tracker.refresh() diff --git a/src/software/py_constants.cpp b/src/software/py_constants.cpp index ce7b5f1f1a..4b6c68e0c9 100644 --- a/src/software/py_constants.cpp +++ b/src/software/py_constants.cpp @@ -40,6 +40,9 @@ PYBIND11_MODULE(py_constants, m) ACCELERATION_DUE_TO_GRAVITY_METERS_PER_SECOND_SQUARED; m.attr("ENEMY_BALL_PLACEMENT_DISTANCE_METERS") = ENEMY_BALL_PLACEMENT_DISTANCE_METERS; + m.attr("BALL_TO_FRONT_OF_ROBOT_DISTANCE_WHEN_DRIBBLING") = + BALL_TO_FRONT_OF_ROBOT_DISTANCE_WHEN_DRIBBLING; + m.attr("TACTIC_OVERRIDE_PATH") = TACTIC_OVERRIDE_PATH; m.attr("PLAY_OVERRIDE_PATH") = PLAY_OVERRIDE_PATH; diff --git a/src/software/thunderscope/BUILD b/src/software/thunderscope/BUILD index 769f9781f2..c9d685d643 100644 --- a/src/software/thunderscope/BUILD +++ b/src/software/thunderscope/BUILD @@ -19,11 +19,11 @@ py_binary( ":estop_helpers", ":thunderscope", ":util", + "//software/evaluation/loggers:stats_logger", "//software/thunderscope/binary_context_managers:full_system", "//software/thunderscope/binary_context_managers:game_controller", "//software/thunderscope/binary_context_managers:runtime_manager", "//software/thunderscope/binary_context_managers:simulator", - "//software/thunderscope/log/stats", ], ) diff --git a/src/software/thunderscope/constants.py b/src/software/thunderscope/constants.py index eb5bc0bb00..2c8d1a52cf 100644 --- a/src/software/thunderscope/constants.py +++ b/src/software/thunderscope/constants.py @@ -402,17 +402,8 @@ class RuntimeManagerConstants: EXTERNAL_RUNTIMES_PATH = "/opt/tbotspython/external_runtimes" RUNTIME_CONFIG_PATH = f"{EXTERNAL_RUNTIMES_PATH}/runtime_config.toml" - RUNTIME_STATS_DIRECTORY_PATH = "/tmp/tbots/stats" - RUNTIME_FRIENDLY_STATS_FILE = "blue.toml" - RUNTIME_ENEMY_FROM_FRIENDLY_STATS_FILE = "yellow_from_blue.toml" - RUNTIME_ENEMY_STATS_FILE = "yellow.toml" - RUNTIME_FRIENDLY_FROM_ENEMY_STATS_FILE = "blue_from_yellow.toml" - - RUNTIME_STATS_SCORE_KEY = "goals" - RUNTIME_STATS_RED_CARDS_KEY = "red_cards" - RUNTIME_STATS_YELLOW_CARDS_KEY = "yellow_cards" - RUNTIME_STATS_SHOTS_ON_NET = "shots_on_net" - RUNTIME_STATS_SHOTS_BLOCKED = "shots_blocked" + RUNTIME_EVENTS_DIRECTORY_PATH = "/tmp/tbots/stats" + RUNTIME_EVENTS_FILE = "game_events.csv" RELEASES_URL = "https://api.github.com/repos/UBC-Thunderbots/Software/releases" DOWNLOAD_URL = "https://github.com/UBC-Thunderbots/Software/releases/download/" diff --git a/src/software/thunderscope/log/stats/BUILD b/src/software/thunderscope/log/stats/BUILD deleted file mode 100644 index 20f800a11d..0000000000 --- a/src/software/thunderscope/log/stats/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -package(default_visibility = ["//visibility:public"]) - -py_library( - name = "fullsystem_stats", - srcs = ["fullsystem_stats.py"], - deps = [ - "//software/thunderscope/log/trackers:tracker", - ], -) - -py_library( - name = "stats", - srcs = ["stats.py"], - deps = [ - ":fullsystem_stats", - "//software/thunderscope:thread_safe_buffer", - ], -) diff --git a/src/software/thunderscope/log/stats/fullsystem_stats.py b/src/software/thunderscope/log/stats/fullsystem_stats.py deleted file mode 100644 index 77049dee15..0000000000 --- a/src/software/thunderscope/log/stats/fullsystem_stats.py +++ /dev/null @@ -1,239 +0,0 @@ -import os - -from software.thunderscope.log.trackers import ( - PossessionTracker, - ShotTracker, - TrackerBuilder, - RefereeTracker, - GoalieTracker, -) -from dataclasses import dataclass -from software.thunderscope.proto_unix_io import ProtoUnixIO -from software.thunderscope.constants import RuntimeManagerConstants -import logging -from proto.import_all_protos import * -from rich import print - - -@dataclass -class FSStats: - """Stats for how well a FullSystem is performing""" - - num_yellow_cards: int = 0 - num_red_cards: int = 0 - num_scores: int = 0 - - num_shots_on_net: int = 0 - num_enemy_shots_blocked: int = 0 - - -class FullSystemStats: - # From GoalieTacticConfig - INCOMING_SHOT_MIN_VELOCITY = 0.2 - - def __init__( - self, - proto_unix_io: ProtoUnixIO, - friendly_colour_yellow: bool, - buffer_size: int = 5, - record_enemy_stats: bool = False, - ): - """Initializes the FullSystem Stats Tracker - - :param friendly_colour_yellow: if the friendly colour is yellow - :param buffer_size: the buffer size for protocol buffers - :param record_enemy_stats: if this should record both friendly and enemy stats or just friendly - """ - self.friendly_colour_yellow = friendly_colour_yellow - - # True if friendly had the last possession, False if enemy - # None if neither - self.last_possession_friendly: bool | None = None - - # use both trackers to keep track of shots on net - # take the min to avoid noise from the attacker tactic visualization - self.num_shots_on_net_attacker: int = 0 - self.num_shots_on_net_goalie: int = 0 - - self.stats = FSStats() - - # these should be set up using the setup method - self.stats_file = None - self.enemy_stats_file = None - - # the python __del__ destructor isn't called reliably - # so printing this at the start instead - print(f"[bold red]Writing FS Stats to {self._get_stats_file()}") - - self.tracker = ( - TrackerBuilder(proto_unix_io=proto_unix_io) - .add_tracker( - ShotTracker, callback=self._update_shot_count, buffer_size=buffer_size - ) - .add_tracker( - PossessionTracker, - callback=self._update_posession, - buffer_size=buffer_size, - ) - .add_tracker( - RefereeTracker, - callback=self._update_referee_info_friendly, - friendly_color_yellow=self.friendly_colour_yellow, - buffer_size=buffer_size, - ) - .add_tracker( - GoalieTracker, - callback=self._update_goalie_shot_friendly, - for_friendly=True, - buffer_size=buffer_size, - ) - ) - - self.record_enemy_stats = record_enemy_stats - if self.record_enemy_stats: - self.enemy_stats = FSStats() - self.tracker = self.tracker.add_tracker( - RefereeTracker, - callback=self._update_referee_info_enemy, - friendly_color_yellow=(not self.friendly_colour_yellow), - ).add_tracker( - GoalieTracker, - callback=self._update_goalie_shot_enemy, - for_friendly=False, - buffer_size=buffer_size, - ) - - print(f"[bold red]Writing Enemy FS Stats to {self._get_enemy_stats_file()}") - - def refresh(self) -> None: - """Refreshes the stats for the game so far""" - self.tracker.refresh() - - self._flush_stats() - - def _update_shot_count(self, _: Shot): - self.stats.num_shots_on_net += 1 - - def _update_posession(self, friendly_posession: bool | None): - self.last_possession_friendly = not friendly_posession - - def _update_referee_info_friendly( - self, num_goals: int, num_yellow_cards: int, num_red_cards: int - ) -> None: - # the callback for tracking incoming shots can't differentiate between a blocked shot vs a goal - # so we subtract all "blocked" shots that were actually goals, so not blocked - if self.stats.num_scores < num_goals: - self.stats.num_enemy_shots_blocked -= 1 - - self.stats.num_scores = num_goals - self.stats.num_yellow_cards = num_yellow_cards - self.stats.num_red_cards = num_red_cards - - def _update_referee_info_enemy( - self, num_goals: int, num_yellow_cards: int, num_red_cards: int - ) -> None: - # the callback for tracking incoming shots can't differentiate between a blocked shot vs a goal - # so we subtract all "blocked" shots that were actually goals, so not blocked - if self.enemy_stats.num_scores < num_goals: - self.enemy_stats.num_enemy_shots_blocked -= 1 - - self.enemy_stats.num_scores = num_goals - self.enemy_stats.num_yellow_cards = num_yellow_cards - self.enemy_stats.num_red_cards = num_red_cards - - def _update_goalie_shot_friendly( - self, is_shot_incoming: bool, last_shot_incoming: bool - ) -> None: - if not is_shot_incoming and last_shot_incoming: - self.stats.num_enemy_shots_blocked += 1 - - if self.record_enemy_stats: - if is_shot_incoming and not last_shot_incoming: - self.enemy_stats.num_shots_on_net += 1 - - def _update_goalie_shot_enemy( - self, is_shot_incoming: bool, last_shot_incoming: bool - ) -> None: - if self.record_enemy_stats: - if not is_shot_incoming and last_shot_incoming: - self.enemy_stats.num_enemy_shots_blocked += 1 - - def _get_stats_file(self): - return os.path.join( - RuntimeManagerConstants.RUNTIME_STATS_DIRECTORY_PATH, - RuntimeManagerConstants.RUNTIME_ENEMY_STATS_FILE - if self.friendly_colour_yellow - else RuntimeManagerConstants.RUNTIME_FRIENDLY_STATS_FILE, - ) - - def _get_enemy_stats_file(self): - return os.path.join( - RuntimeManagerConstants.RUNTIME_STATS_DIRECTORY_PATH, - RuntimeManagerConstants.RUNTIME_FRIENDLY_FROM_ENEMY_STATS_FILE - if self.friendly_colour_yellow - else RuntimeManagerConstants.RUNTIME_ENEMY_FROM_FRIENDLY_STATS_FILE, - ) - - def setup(self): - """Sets up the file resources for logging - Creates any missing directories and stores the file handle - """ - stats_file_name = self._get_stats_file() - - # create temp stats directory if it doesn't exist - os.makedirs(os.path.dirname(stats_file_name), exist_ok=True) - - self.stats_file = open(stats_file_name, "w") - - if self.record_enemy_stats: - enemy_stats_file_name = self._get_enemy_stats_file() - - # create temp stats directory if it doesn't exist - os.makedirs(os.path.dirname(enemy_stats_file_name), exist_ok=True) - - self.enemy_stats_file = open(enemy_stats_file_name, "w") - - def cleanup(self): - """Writes all logs back to file, and cleans up any created file resources after logging""" - self._flush_stats() - - if self.stats_file: - self.stats_file.flush() - self.stats_file.close() - - if self.record_enemy_stats and self.enemy_stats_file: - self.enemy_stats_file.flush() - self.enemy_stats_file.close() - - def _flush_stats(self): - """Write the current stats to disk""" - self._write_stats_to_file(self.stats, self.stats_file) - - if self.record_enemy_stats: - self._write_stats_to_file(self.enemy_stats, self.enemy_stats_file) - - def _write_stats_to_file(self, stats: FSStats, stats_file) -> None: - """Write the given stats to the given file - - :param stats: the stats to write - :param stats_file: handle to the file to write to - """ - if not stats_file: - return - - try: - # formatted as key-value pairs in TOML - stats_to_write = ( - f'{RuntimeManagerConstants.RUNTIME_STATS_SCORE_KEY} = "{stats.num_scores}"\n' - f'{RuntimeManagerConstants.RUNTIME_STATS_RED_CARDS_KEY} = "{stats.num_red_cards}"\n' - f'{RuntimeManagerConstants.RUNTIME_STATS_YELLOW_CARDS_KEY} = "{stats.num_yellow_cards}"\n' - f'{RuntimeManagerConstants.RUNTIME_STATS_SHOTS_ON_NET} = "{stats.num_shots_on_net}"\n' - f'{RuntimeManagerConstants.RUNTIME_STATS_SHOTS_BLOCKED} = "{stats.num_enemy_shots_blocked}"' - ) - - stats_file.seek(0) - stats_file.write(stats_to_write) - stats_file.truncate() - - except (FileNotFoundError, PermissionError): - logging.warning("Failed to write TOML FS stats file") diff --git a/src/software/thunderscope/log/stats/stats.py b/src/software/thunderscope/log/stats/stats.py deleted file mode 100644 index de7e9074fc..0000000000 --- a/src/software/thunderscope/log/stats/stats.py +++ /dev/null @@ -1,33 +0,0 @@ -from software.thunderscope.log.stats.fullsystem_stats import FullSystemStats -from software.thunderscope.proto_unix_io import ProtoUnixIO -from proto.import_all_protos import * - - -class Stats: - """This class is a wrapper for all Statistics related operations we want to do with FullSystem or Thunderscope""" - - def __init__( - self, - proto_unix_io: ProtoUnixIO, - friendly_color_yellow: bool = False, - record_enemy_stats: bool = False, - buffer_size: int = 5, - ): - self.proto_unix_io = proto_unix_io - - self.fs_stats = FullSystemStats( - friendly_colour_yellow=friendly_color_yellow, - proto_unix_io=proto_unix_io, - buffer_size=buffer_size, - record_enemy_stats=record_enemy_stats, - ) - - def refresh(self): - self.fs_stats.refresh() - - def __enter__(self): - self.fs_stats.setup() - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.fs_stats.cleanup() diff --git a/src/software/thunderscope/log/trackers/__init__.py b/src/software/thunderscope/log/trackers/__init__.py deleted file mode 100644 index 75eb8e1915..0000000000 --- a/src/software/thunderscope/log/trackers/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from software.thunderscope.log.trackers.kick_tracker import ShotTracker, PassTracker -from software.thunderscope.log.trackers.possession_tracker import PossessionTracker -from software.thunderscope.log.trackers.tracker_builder import TrackerBuilder -from software.thunderscope.log.trackers.referee_tracker import RefereeTracker -from software.thunderscope.log.trackers.goalie_tracker import GoalieTracker - -from software.thunderscope.log.trackers.tracked_event import EventType, TrackedEvent - -__all__ = [ - "PossessionTracker", - "ShotTracker", - "PassTracker", - "TrackerBuilder", - "RefereeTracker", - "GoalieTracker", - "TrackedEvent", - "EventType", -] diff --git a/src/software/thunderscope/log/trackers/possession_tracker.py b/src/software/thunderscope/log/trackers/possession_tracker.py deleted file mode 100644 index 5dd3a7b4a8..0000000000 --- a/src/software/thunderscope/log/trackers/possession_tracker.py +++ /dev/null @@ -1,87 +0,0 @@ -from software.thunderscope.log.trackers.tracker import Tracker -from typing import override, Callable -from software.thunderscope.thread_safe_buffer import ThreadSafeBuffer -from software.thunderscope.proto_unix_io import ProtoUnixIO -from proto.import_all_protos import * -import software.python_bindings as tbots_cpp - - -class PossessionTracker(Tracker): - """Tracker to track when ball possession changes""" - - def __init__(self, callback: Callable[[bool | None], None], buffer_size: int = 5): - """Initializes the Possession tracker - - :param callback: function to call when possession changes - called with an optional bool value: - - True: friendly possession - - False: enemy possession - - None: neither - :param buffer_size: buffer size for the tracker's io - """ - super().__init__(callback=callback, buffer_size=buffer_size) - - self.world_buffer = ThreadSafeBuffer(self.buffer_size, World) - - @override - def set_proto_unix_io(self, proto_unix_io: ProtoUnixIO) -> None: - super().set_proto_unix_io( - proto_unix_io, - [ - (World, self.world_buffer), - ], - ) - - @override - def refresh(self): - """Refresh and update the callback with the latest ball possession""" - world_msg = self.world_buffer.get(block=False, return_cached=True) - - if world_msg is None: - return - - world = tbots_cpp.World(world_msg) - - if self.callback: - self.callback( - self._check_posession_for_friendly( - world.friendlyTeam(), world.enemyTeam(), world.ball().position() - ) - ) - - def _check_posession_for_friendly( - self, - friendly_team: tbots_cpp.Team, - enemy_team: tbots_cpp.Team, - ball_position: tbots_cpp.Point, - ) -> bool | None: - """Check for if the friendly team has possession of the ball - True if they do, False if enemy team has possession, None if neither - - :param friendly_team: the friendly team - :param enemy_team: the enemy team - :param ball_position: the current ball position - :return: True / False / None depending on which team has possession - """ - if self._check_posession_for_team(friendly_team, ball_position): - return True - - if self._check_posession_for_team(enemy_team, ball_position): - return False - - return None - - def _check_posession_for_team( - self, team: tbots_cpp.Team, ball_position: tbots_cpp.Point - ) -> bool: - """Check if the given team has possession of the ball - - :param team: the team to check - :param ball_position: the current ball position - :return: True if the team has possession, False otherwise - """ - for robot in team.getAllRobots(): - if robot.isNearDribbler(ball_position): - return True - - return False diff --git a/src/software/thunderscope/log/trackers/referee_tracker.py b/src/software/thunderscope/log/trackers/referee_tracker.py deleted file mode 100644 index b7d824d08b..0000000000 --- a/src/software/thunderscope/log/trackers/referee_tracker.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import override, Callable -from software.thunderscope.log.trackers.tracker import Tracker -from proto.import_all_protos import * -from software.thunderscope.thread_safe_buffer import ThreadSafeBuffer -from software.thunderscope.proto_unix_io import ProtoUnixIO - - -class RefereeTracker(Tracker): - """Tracks Referee events, like goals and yellow / red cards for the friendly team only""" - - def __init__( - self, - friendly_color_yellow: bool, - callback: Callable[[int, int, int], None], - buffer_size: int = 5, - ): - """Initializes the Referee tracker - - :param friendly_color_yellow: if the friendly color is yellow or not - determines which goals, etc. the tracker tracks - :param callback: function to call when there is any new Referee event - called with the current goals, yellow cards, and red cards - :param buffer_size: buffer size for the tracker's io - """ - super().__init__(callback=callback, buffer_size=buffer_size) - - self.referee_buffer = ThreadSafeBuffer(buffer_size, Referee) - - self.friendly_color_yellow = friendly_color_yellow - - @override - def set_proto_unix_io(self, proto_unix_io: ProtoUnixIO) -> None: - super().set_proto_unix_io( - proto_unix_io, - [ - (Referee, self.referee_buffer), - ], - ) - - @override - def refresh(self): - """Refresh and update the callback with the latest referee information""" - refree_msg = self.referee_buffer.get(block=False, return_cached=True) - - if not refree_msg: - return - - if refree_msg.HasField("yellow" if self.friendly_color_yellow else "blue"): - team_info = ( - refree_msg.yellow if self.friendly_color_yellow else refree_msg.blue - ) - - num_goals = 0 - num_yellow_cards = 0 - num_red_cards = 0 - - if team_info.HasField("score"): - num_goals = team_info.score - - if team_info.HasField("yellow_cards"): - num_yellow_cards = team_info.yellow_cards - - if team_info.HasField("red_cards"): - num_red_cards = team_info.red_cards - - if self.callback: - self.callback(num_goals, num_yellow_cards, num_red_cards) diff --git a/src/software/thunderscope/log/trackers/tracked_event.py b/src/software/thunderscope/log/trackers/tracked_event.py deleted file mode 100644 index 0ae6a22a3a..0000000000 --- a/src/software/thunderscope/log/trackers/tracked_event.py +++ /dev/null @@ -1,162 +0,0 @@ -from dataclasses import dataclass -from enum import StrEnum, auto -from proto.import_all_protos import * -from typing import Any -from software.py_constants import DIV_B_NUM_ROBOTS -from google.protobuf.descriptor import Descriptor, FieldDescriptor -from software.thunderscope.time_provider import time_provider_instance - - -def count_primitive_fields(descriptor: Descriptor): - """Recursively counts the number of primitive fields in a Protobuf message - using its descriptor. - - :param message: the message descriptor to count all leaf-level primitive fields for - :return: the count of primitive fields - """ - count = 0 - - for field in descriptor.fields: - # Check if the field is a nested message - if field.type == FieldDescriptor.TYPE_MESSAGE: - # Get the nested message class to recurse into its descriptor - nested_message = field.message_type - # Recurse using the nested message's descriptor - count += count_primitive_fields(nested_message) - else: - # It's a primitive type (double, float, int, bool, string, etc.) - count += 1 - return count - - -NUM_ROBOT_FIELDS = count_primitive_fields(RobotState.DESCRIPTOR) - - -class EventType(StrEnum): - """Enum for the different types of events we want to track""" - - PASS = auto() - SHOT_ON_GOAL = auto() - ENEMY_SHOT_ON_GOAL = auto() - SHOT_BLOCKED = auto() - FRIENDLY_POSSESSION_START = auto() - FRIENDLY_POSSESSION_END = auto() - ENEMY_POSSESSION_START = auto() - ENEMY_POSSESSION_END = auto() - GAME_START = auto() - GAME_END = auto() - GOAL_SCORED = auto() - YELLOW_CARD = auto() - RED_CARD = auto() - - -class Team(StrEnum): - """The teams present in the game""" - - BLUE = auto() - YELLOW = auto() - - -@dataclass -class Robot: - """Represents a single robot on the field, with ID and current state.""" - - id: int - state: RobotState - - -@dataclass -class TrackedEvent: - """Represents a single event being tracked, where and for whom the event is, and the game state at the time of the event""" - - event_type: EventType - timestamp: float - from_team: Team - for_team: Team - ball_state: BallState - friendly_robots: list[Robot] - enemy_robots: list[Robot] - - -def get_event_from_world( - world_msg: World, event_type: EventType, from_team: Team, for_team: Team -) -> TrackedEvent: - """Creates a TrackedEvent from a world protobuf message - - :param world_msg: the world object containing the state of the game - :param event_type: the type of event being recorded - :param from_team: the team that the event is coming from - :param for_team: the team that the event is for - :return: a fully populated TrackedEvent including ball and robot states - """ - ball_state = world_msg.ball.current_state - - friendly_robots = [ - Robot(id=robot.id, state=robot.current_state) - for robot in world_msg.friendly_team.team_robots - ] - enemy_robots = [ - Robot(id=robot.id, state=robot.current_state) - for robot in world_msg.enemy_team.team_robots - ] - - return TrackedEvent( - timestamp=time_provider_instance.elapsed_time_ns(), - event_type=event_type, - from_team=from_team, - for_team=for_team, - ball_state=ball_state, - friendly_robots=friendly_robots, - enemy_robots=enemy_robots, - ) - - -def add_robots_to_row(row: list[Any], robots: list[Robot]) -> None: - """Serializes robots into flattened columns within a list row - - :param row: the existing list representing a CSV row to be appended to - :param robot_states: the list of RobotState objects to flatten and add - :return: None (the row list is modified in place) - """ - # robot state columns will be added based on robot id - robot_state_map = {robot.id: robot.state for robot in robots} - - # Add friendly robots: [r1_data, r2_data...] based on id - for idx in range(DIV_B_NUM_ROBOTS): - if idx not in robot_state_map: - row.extend([None] * NUM_ROBOT_FIELDS) - else: - robot_state = robot_state_map[idx] - robot_row = [ - robot_state.global_position.x_meters, - robot_state.global_position.y_meters, - robot_state.global_orientation.radians, - robot_state.global_velocity.x_component_meters, - robot_state.global_velocity.y_component_meters, - robot_state.global_angular_velocity.radians_per_second, - ] - - assert len(robot_row) == NUM_ROBOT_FIELDS - - row.extend(robot_row) - - -def event_to_csv_row(event: TrackedEvent) -> str: - """Serializes a TrackedEvent into a flat CSV string row - - :param event: the TrackedEvent object to convert - :return: a comma-separated string of all event attributes and robot states - """ - row = [event.event_type.value, event.timestamp] - - row = row + [ - event.ball_state.global_position.x_meters, - event.ball_state.global_position.y_meters, - event.ball_state.global_velocity.x_component_meters, - event.ball_state.global_velocity.y_component_meters, - ] - - add_robots_to_row(row, event.friendly_robots) - add_robots_to_row(row, event.enemy_robots) - - return ",".join([str(elem) for elem in row]) diff --git a/src/software/thunderscope/log/trackers/tracker.py b/src/software/thunderscope/log/trackers/tracker.py deleted file mode 100644 index f5f4aa319a..0000000000 --- a/src/software/thunderscope/log/trackers/tracker.py +++ /dev/null @@ -1,36 +0,0 @@ -from software.thunderscope.proto_unix_io import ProtoUnixIO -from typing import Callable, Optional, Tuple, Type -from software.thunderscope.thread_safe_buffer import ThreadSafeBuffer -from google.protobuf.message import Message - - -class Tracker: - """Generic tracker base class.""" - - def __init__( - self, callback: Optional[Callable[..., None]] = None, buffer_size: int = 5 - ): - """Initializes the tracker with the given callback and buffer size - - :param callback: the function to call when the tracker tracks an event - :param buffer_size: buffer size for the tracker's io - """ - self.callback = callback - self.buffer_size = buffer_size - - def set_proto_unix_io( - self, - proto_unix_io: ProtoUnixIO, - type_buffers: list[Tuple[Type[Message], ThreadSafeBuffer]], - ) -> None: - """Registers the given message types and buffers to the given proto unix io connection - - :param proto_unix_io: the io connection to listen on - :param type_buffers: a list of (Message Type, Buffer) tuples. - messages of each type will be placed into their corresponding buffer - """ - for message_type, buffer in type_buffers: - proto_unix_io.register_observer(message_type, buffer) - - def refresh(self) -> None: - raise Exception("Not Implemented, please use the appropriate subclass!") diff --git a/src/software/thunderscope/log/trackers/tracker_builder.py b/src/software/thunderscope/log/trackers/tracker_builder.py deleted file mode 100644 index 903458b6e7..0000000000 --- a/src/software/thunderscope/log/trackers/tracker_builder.py +++ /dev/null @@ -1,38 +0,0 @@ -from software.thunderscope.proto_unix_io import ProtoUnixIO -from software.thunderscope.log.trackers.tracker import Tracker -from typing import Callable, Any, Optional, Type, Self - - -class TrackerBuilder: - """Builder class to combine different trackers and update them together""" - - def __init__(self, proto_unix_io: ProtoUnixIO) -> None: - """Initializes the builder - - :param proto_unix_io: the unix io that the trackers should listen on - """ - self.proto_unix_io = proto_unix_io - - self.trackers = [] - - def add_tracker( - self, - tracker_cls: Type[Tracker], - callback: Optional[Callable[[Any], None]] = None, - **kwargs, - ) -> Self: - """Adds a single tracker to the list - - :param tracker_cls: The class of the tracker to instantiate - :param callback: function that the tracker should call when it tracks an event - :param **kwargs: tracker-specific arguments - """ - tracker = tracker_cls(callback=callback, **kwargs) - tracker.set_proto_unix_io(self.proto_unix_io) - self.trackers.append(tracker) - return self - - def refresh(self) -> None: - """Refreshes all the trackers""" - for tracker in self.trackers: - tracker.refresh() diff --git a/src/software/thunderscope/thunderscope_main.py b/src/software/thunderscope/thunderscope_main.py index c35b675f52..edee6685ef 100644 --- a/src/software/thunderscope/thunderscope_main.py +++ b/src/software/thunderscope/thunderscope_main.py @@ -11,7 +11,7 @@ from software.thunderscope.binary_context_managers.runtime_manager import ( runtime_manager_instance, ) -from software.thunderscope.log.stats.stats import Stats +from software.evaluation.loggers.stats_logger import StatsLogger from software.thunderscope.thunderscope import Thunderscope from software.thunderscope.constants import LogLevels @@ -499,17 +499,21 @@ def __ticker(tick_rate_ms: int) -> None: if args.enable_autoref else contextlib.nullcontext() ) as autoref, ( - Stats( + StatsLogger( proto_unix_io=tscope.proto_unix_io_map[ProtoUnixIOTypes.BLUE], record_enemy_stats=True, + friendly_colour_yellow=False, ) if args.record_stats else contextlib.nullcontext() - ) as blue_stats, ( - Stats(proto_unix_io=tscope.proto_unix_io_map[ProtoUnixIOTypes.YELLOW]) + ) as blue_stats_logger, ( + StatsLogger( + proto_unix_io=tscope.proto_unix_io_map[ProtoUnixIOTypes.YELLOW], + friendly_colour_yellow=True, + ) if args.record_stats else contextlib.nullcontext() - ) as yellow_stats: + ) as yellow_stats_logger: tscope.register_refresh_function(gamecontroller.refresh) autoref_proto_unix_io = ProtoUnixIO() @@ -520,9 +524,9 @@ def __ticker(tick_rate_ms: int) -> None: tscope.proto_unix_io_map[ProtoUnixIOTypes.YELLOW] ) - if args.record_stats: - tscope.register_refresh_function(blue_stats.refresh) - tscope.register_refresh_function(yellow_stats.refresh) + if args.record_stats and blue_stats_logger and yellow_stats_logger: + tscope.register_refresh_function(blue_stats_logger.refresh) + tscope.register_refresh_function(yellow_stats_logger.refresh) simulator.setup_proto_unix_io( tscope.proto_unix_io_map[ProtoUnixIOTypes.SIM],