diff --git a/benchmarks/aeos.py b/benchmarks/aeos.py new file mode 100644 index 00000000..d9b96d52 --- /dev/null +++ b/benchmarks/aeos.py @@ -0,0 +1,158 @@ +import numpy as np +from Basilisk.utilities import orbitalMotion +from benchmark import BenchmarkEnv +from ray.rllib.core.rl_module.rl_module import RLModuleSpec + +from bsk_rl import act, comm, data, obs, sats, scene +from bsk_rl.sim import fsw +from bsk_rl.utils.orbital import random_circular_orbit, walker_delta_args + + +class AEOS(sats.ImagingSatellite): + action_spec = [act.Image(n_ahead_image=32)] + observation_spec = [ + obs.SatProperties( + dict(prop="omega_BH_H", norm=0.03), + dict(prop="c_hat_H"), + dict(prop="r_BN_P", norm=orbitalMotion.REQ_EARTH * 1e3), + dict(prop="v_BN_P", norm=7616.5), + ), + obs.OpportunityProperties( + dict(prop="priority"), + dict(prop="r_LB_H", norm=800 * 1e3), + dict(prop="target_angle", norm=np.pi / 2), + dict(prop="target_angle_rate", norm=0.03), + dict(prop="opportunity_open", norm=300.0), + dict(prop="opportunity_close", norm=300.0), + type="target", + n_ahead_observe=32, + ), + obs.Time(), + ] + + fsw_type = (fsw.SteeringFSWModel, fsw.ImagingFSWModel) + + +SAT_ARGS = dict( + imageAttErrorRequirement=0.01, + imageRateErrorRequirement=0.01, + batteryStorageCapacity=80.0 * 3600 * 100, + storedCharge_Init=80.0 * 3600 * 100.0, + dataStorageCapacity=200 * 8e6 * 100, + u_max=0.4, + imageTargetMinimumElevation=np.arctan(800 / 500), + K1=0.25, + K3=3.0, + omega_max=np.radians(5), + servo_Ki=5.0, + servo_P=150 / 5, + oe=lambda: random_circular_orbit(alt=800, i=45), +) + + +def episode_data_callback(env): + data = {} + + imaged = env.rewarder.data.imaged + reward = sum(env.rewarder.cum_reward.values()) + + data["imaged"] = len(imaged) + data["reward"] = reward + data["avg_tgt_val"] = reward / len(imaged) if len(imaged) > 0 else 0.0 + data["duplicates"] = env.rewarder.data.duplicates + + return data + + +def satellite_data_callback(env, satellite): + data = {} + + imaged = 
satellite.imaged + missed = satellite.missed + + reward = env.rewarder.cum_reward[satellite.name] + + duration = max(env.simulator.sim_time, 0.01) + orbits = duration / (95 * 60) + + data["imaged"] = imaged + data["imaged_per_orbit"] = imaged / orbits + data["missed"] = missed + data["missed_per_orbit"] = missed / orbits + data["reward"] = reward + data["reward_per_orbit"] = reward / orbits + if imaged == 0: + data["avg_tgt_val"] = 0 + data["success_rate"] = 0 + else: + data["avg_tgt_val"] = reward / imaged + data["success_rate"] = imaged / (imaged + missed) + data["attempts"] = imaged + missed + data["attempts_per_orbit"] = (imaged + missed) / orbits + + data["orbits_completed"] = orbits + data["alive"] = float(satellite.is_alive()) + + return data + + +def gen_env_args(n_satellites=1): + env_args = dict( + satellites=[ + AEOS(name=f"EO{i + 1}", sat_args=SAT_ARGS) for i in range(n_satellites) + ], + scenario=scene.UniformTargets((100, 10000)), + rewarder=data.UniqueImageReward(), + communicator=comm.FreeCommunication(min_period=60), + sim_rate=0.5, + max_step_duration=300.0, + time_limit=5700 * 3, + failure_penalty=0.0, + terminate_on_time_limit=True, + # episode_data_callback=episode_data_callback, + # satellite_data_callback=satellite_data_callback, + ) + if n_satellites > 1: + env_args["sat_arg_randomizer"] = walker_delta_args( + n_planes=1, altitude=800, inc=45 + ) + return env_args + + +policies = {"policy"} +policy_mapping_fn = lambda agent_id, *args, **kwargs: "policy" +module_specs = { + "policy": RLModuleSpec( + model_config_dict={ + "use_lstm": False, + "fcnet_hiddens": [1024, 1024], + "vf_share_layers": False, + }, + ), +} + +training_args = dict( + lr=3e-5, + gamma=0.997, + train_batch_size=3000, + num_sgd_iter=10, + use_kl_loss=False, + clip_param=0.2, + grad_clip=0.5, +) + +aeos_single = BenchmarkEnv( + env_args=gen_env_args(n_satellites=1), + policies=policies, + policy_mapping_fn=policy_mapping_fn, + module_specs=module_specs, + 
training_args=training_args, +) + +aeos_constellation = BenchmarkEnv( + env_args=gen_env_args(n_satellites=3), + policies=policies, + policy_mapping_fn=policy_mapping_fn, + module_specs=module_specs, + training_args=training_args, +) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py new file mode 100644 index 00000000..9bb62763 --- /dev/null +++ b/benchmarks/benchmark.py @@ -0,0 +1,387 @@ +import multiprocessing as mp +import os +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Callable + +import numpy as np +import ray +import torch +from Basilisk.architecture import bskLogging +from ray.rllib.algorithms.ppo import PPO, PPOConfig +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.tune.logger import UnifiedLogger + +from bsk_rl.utils.rllib.callbacks import WrappedEpisodeDataCallbacks +from bsk_rl.utils.rllib.discounting import ( + CondenseMultiStepActions, + ContinuePreviousAction, + ContinuePreviousActionAppended, + MakeAddedStepActionValid, + TimeDiscountedGAEPPOTorchLearner, +) + +bskLogging.setDefaultLogLevel(bskLogging.BSK_WARNING) + + +@dataclass +class BenchmarkEnv: + env_args: dict + policies: set + policy_mapping_fn: Callable + module_specs: dict + training_args: dict + + +# TODO remove, for cluster only +torch.set_num_threads(11) +os.environ["MKL_NUM_THREADS"] = "11" + + +def get_available_cores(): + """Returns the number of available CPU cores, accounting for SLURM environment variables if present.""" + try: + processes = int(os.environ["SLURM_JOB_CPUS_PER_NODE"].split("(")[0]) + except Exception: + processes = mp.cpu_count() + return processes + + +def create_new_model( + output_directory, + env_args, + policies, + policy_mapping_fn, + module_specs, + num_env_runners=1, + training_args={}, + temp_dir="/tmp", +): + """Configure a PPO model for training with sMDP discounting and asynchronous multiagent 
actions.""" + + os.environ["RAY_TMPDIR"] = os.environ["TMPDIR"] = temp_dir + output_directory = Path(output_directory) + output_directory.mkdir(exist_ok=True, parents=True) + + multi_agent = len(env_args["satellites"]) > 1 + + ray.init( + ignore_reinit_error=True, + num_cpus=get_available_cores(), + object_store_memory=2_000_000_000, # 2 GB + _temp_dir=temp_dir, + ) + + # Add connectors for async multi-agent actions if needed + module_to_env_connector = dict() + if multi_agent: + module_to_env_connector = dict( + module_to_env_connector=lambda env: (ContinuePreviousAction(),) + ) + + config = ( + PPOConfig() + .training(**training_args) + .env_runners( + num_env_runners=num_env_runners, + sample_timeout_s=50000.0, + **module_to_env_connector, + ) + .environment( + env="ConstellationTasking-RLlib", + env_config=env_args, + ) + .callbacks(WrappedEpisodeDataCallbacks) + .reporting( + metrics_num_episodes_for_smoothing=1, + metrics_episode_collection_timeout_s=180, + ) + .checkpointing(export_native_model_files=True) + .framework(framework="torch") + .api_stack( + enable_rl_module_and_learner=True, + enable_env_runner_and_connector_v2=True, + ) + .resources(num_gpus=0) + .multi_agent( + policies=policies, + policy_mapping_fn=policy_mapping_fn, + ) + .rl_module( + rl_module_spec=MultiRLModuleSpec( + module_specs=module_specs, + ), + ) + ) + + # Add connectors for async multi-agent actions if needed + learning_connector = dict() + if multi_agent: + learning_connector = dict( + learner_connector=lambda obs_space, act_space: ( + MakeAddedStepActionValid( + expected_train_batch_size=config.train_batch_size + ), + CondenseMultiStepActions(), + ), + ) + + config.training( + **training_args, + **learning_connector, + learner_class=TimeDiscountedGAEPPOTorchLearner, + learner_config_dict=dict(reward_time="step_end"), + ) + config.logger_config = dict(type=UnifiedLogger, logdir=output_directory) + + # Add connectors for async multi-agent actions if needed + if multi_agent: + 
old_connector_builder = config.build_module_to_env_connector + + def new_connector_builder(env): + pipeline = old_connector_builder(env) + pipeline.insert_after( + "NormalizeAndClipActions", ContinuePreviousActionAppended() + ) + return pipeline + + config.build_module_to_env_connector = new_connector_builder + + ppo = PPO(config) + + return ppo + + +def find_latest_checkpoint(output_directory): + checkpoint_numbers = [] + for folder_name in os.listdir(output_directory): + if folder_name.startswith("checkpoint_"): + num_str = folder_name.split("_")[1] + if num_str.isdigit(): + checkpoint_numbers.append(int(num_str)) + + if not checkpoint_numbers: + print(f"No checkpoints found in {output_directory}") + return 0 + + return max(checkpoint_numbers) + + +def load_existing_model( + output_directory, + temp_dir="/tmp", +): + os.environ["RAY_TMPDIR"] = os.environ["TMPDIR"] = temp_dir + + iter = find_latest_checkpoint(output_directory) + print(f"Starting training from iteration {iter}") + + ray.init( + ignore_reinit_error=True, + num_cpus=get_available_cores(), + object_store_memory=2_000_000_000, # 2 GB + _temp_dir=temp_dir, + ) + checkpoint_path = output_directory / f"checkpoint_{str(iter).zfill(6)}" + ppo = PPO.from_checkpoint(checkpoint_path) + + return ppo + + +def train( + ppo: PPO, + output_directory: Path, + checkpoint_frequency=1, + checkpoints_to_keep=2, + total_timesteps=1_000_000, +): + iter = find_latest_checkpoint(output_directory) + step = 0 + + # Track the best return + current_best_return = -np.inf + if (output_directory / "checkpoint_best" / "return.txt").exists(): + with open(output_directory / "checkpoint_best" / "return.txt", "r") as file: + current_best_return = float(file.read().strip()) + + while True: + print( + f"Starting iteration {iter} at step {step}, current best return: {current_best_return}" + ) + + # Train for one iteration and get the results + results = ppo.train() + step = results["num_env_steps_sampled_lifetime"] + step_return = 
results["env_runners"].get("episode_return_mean", -np.inf) + + # Check if this is the best return we've seen and save a checkpoint if so + if step_return > current_best_return: + checkpoint_path = output_directory / "checkpoint_best" + # if this directory exists, clear it + try: + shutil.rmtree(checkpoint_path) + except FileNotFoundError: + pass + checkpoint_path.mkdir(parents=True, exist_ok=True) + ppo.save_checkpoint(str(checkpoint_path)) + with open( + checkpoint_path / f"iteration_{str(iter).zfill(6)}.txt", "w" + ) as file: + file.write(f"iter: {iter}\n") + current_best_return = step_return + with open(checkpoint_path / "return.txt", "w") as file: + file.write(f"{current_best_return}\n") + + # Save a checkpoint at the specified frequency + if iter % checkpoint_frequency == 0: + checkpoint_path = output_directory / f"checkpoint_{str(iter).zfill(6)}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + ppo.save_checkpoint(str(checkpoint_path)) + + # Delete old checkpoints if we've exceeded the number to keep + if iter > checkpoints_to_keep * checkpoint_frequency - 1: + for i in range(checkpoint_frequency): + remove_dir = ( + output_directory + / f"checkpoint_{str(iter - checkpoints_to_keep * checkpoint_frequency - i).zfill(6)}" + ) + try: + shutil.rmtree(remove_dir) + except FileNotFoundError: + pass + + # Check if we've reached the total timesteps for training + if step > total_timesteps: + break + + iter += 1 + + +if __name__ == "__main__": + import argparse + import sys + + import yaml + from avs_rl_tools.utils import sanitize_np + + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + required=True, + help="Output directory for logs and checkpoints.", + ) + parser.add_argument( + "-e", + "--env", + type=str, + default="nadir_science:nadir_science", + help="Environment to train on. 
Should be in the format 'module_name:benchmark_name'.", + ) + parser.add_argument( + "-j", + "--num_env_runners", + type=int, + default=-1, + help="Number of environments to train with. If -1, use all available cores minus one.", + ) + parser.add_argument( + "-r", + "--restart", + action="store_true", + help="Restart run, deleting existing checkpoints.", + ) + parser.add_argument( + "-t", + "--temp_dir", + type=str, + default="/tmp", + help="Temporary directory for Ray and model checkpoints.", + ) + parser.add_argument( + "-s", + "--total_timesteps", + type=int, + default=20_000_000, + help="Total number of environment steps to train for.", + ) + parser.add_argument( + "--checkpoint_frequency", + type=int, + default=1, + help="Frequency of checkpointing.", + ) + parser.add_argument( + "--checkpoints_to_keep", + type=int, + default=2, + help="Number of checkpoints to keep.", + ) + args = parser.parse_args() + + # Process CLI + output_dir = Path(args.output_dir) + temp_dir = args.temp_dir + if args.num_env_runners != -1: + num_env_runners = args.num_env_runners + else: + num_env_runners = get_available_cores() - 1 + print(f"Tensorboard logging: tensorboard --logdir {output_dir}") + + # Dynamically import the specified benchmark environment + module_name, benchmark_name = args.env.split(":") + module = __import__(f"{module_name}", fromlist=[benchmark_name]) + benchmark_env = getattr(module, benchmark_name) + + # Restart run if specified + if args.restart: + print(f"Restarting run, deleting existing checkpoints in {output_dir}") + try: + shutil.rmtree(output_dir) + except FileNotFoundError: + pass + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Make a new PPO model if no checkpoint exists + if not (output_dir / "checkpoint_best").exists(): + # Record training parameters + with open(output_dir / f"params.txt", "w") as file: + yaml.dump( + sanitize_np( + { + k: getattr(benchmark_env, k) + for k in 
benchmark_env.__dataclass_fields__.keys() + } + ), + file, + ) + ppo = create_new_model( + output_directory=output_dir, + **{ + k: getattr(benchmark_env, k) + for k in benchmark_env.__dataclass_fields__.keys() + }, + num_env_runners=num_env_runners, + temp_dir=temp_dir, + ) + # Otherwise, load the existing model and continue training + else: + ppo = load_existing_model( + output_directory=output_dir, + temp_dir=temp_dir, + ) + + train( + ppo=ppo, + output_directory=output_dir, + checkpoint_frequency=args.checkpoint_frequency, + checkpoints_to_keep=args.checkpoints_to_keep, + total_timesteps=args.total_timesteps, + ) + + sys.exit(0) diff --git a/benchmarks/nadir_science.py b/benchmarks/nadir_science.py new file mode 100644 index 00000000..b4c61b6d --- /dev/null +++ b/benchmarks/nadir_science.py @@ -0,0 +1,132 @@ +import numpy as np +from benchmark import BenchmarkEnv +from ray.rllib.core.rl_module.rl_module import RLModuleSpec + +from bsk_rl import act, data, obs, sats, scene +from bsk_rl.sim import dyn, fsw + + +class ScanningSatellite(sats.AccessSatellite): + observation_spec = [ + obs.SatProperties( + dict(prop="storage_level_fraction"), + dict(prop="battery_charge_fraction"), + dict(prop="wheel_speeds_fraction"), + ), + obs.OpportunityProperties( + dict(prop="opportunity_open", norm=5700), + dict(prop="opportunity_close", norm=5700), + type="ground_station", + n_ahead_observe=1, + ), + obs.Eclipse(norm=5700), + obs.Time(), + ] + action_spec = [ + act.Scan(duration=180.0), + act.Charge(duration=120.0), + act.Downlink(duration=60.0), + act.Desat(duration=60.0), + ] + dyn_type = (dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel) + fsw_type = (fsw.ContinuousImagingFSWModel, fsw.BasicFSWModel) + + +sat_args = dict( + # Data + dataStorageCapacity=5000 * 8e6, # bits + storageInit=lambda: np.random.uniform(0.0, 0.8) * 5000 * 8e6, + instrumentBaudRate=0.5 * 8e6, + transmitterBaudRate=-50 * 8e6, + # Power + batteryStorageCapacity=200 * 3600, # W*s + 
storedCharge_Init=lambda: np.random.uniform(0.3, 1.0) * 200 * 3600, + basePowerDraw=-10.0, # W + instrumentPowerDraw=-30.0, # W + transmitterPowerDraw=-25.0, # W + thrusterPowerDraw=-80.0, # W + panelArea=0.25, + # Attitude + imageAttErrorRequirement=0.1, + imageRateErrorRequirement=0.1, + disturbance_vector=lambda: np.random.normal(scale=0.003, size=3), # N*m + maxWheelSpeed=5000.0, # RPM + wheelSpeeds=lambda: np.random.uniform(-3000, 3000, 3), + desatAttitude="nadir", +) + +DURATION = 5 * 5700.0 # About 5 orbits + + +def episode_data_callback(env): + reward = env.rewarder.cum_reward + reward = sum(reward.values()) / len(reward) + orbits = env.simulator.sim_time / (95 * 60) + + data = dict( + reward=reward, + orbits_complete=orbits, + ) + if orbits > 0: + data["reward_per_orbit"] = reward / orbits + if orbits < DURATION / (95 * 60): + data["orbits_complete_partial_only"] = orbits + + return data + + +def satellite_data_callback(env, sat): + orbits = env.simulator.sim_time / (95 * 60) + data = dict( + alive=float(sat.is_alive()), + rw_status_valid=float(sat.dynamics.rw_speeds_valid()), + battery_status_valid=float(sat.dynamics.battery_valid()), + orbits_complete=orbits, + storage_level=sat.dynamics.storage_level_fraction, + battery_level=sat.dynamics.battery_charge_fraction, + ) + + return data + + +env_args = dict( + satellites=[ScanningSatellite("scanning_sat", sat_args=sat_args)], + scenario=scene.UniformNadirScanning(value_per_second=1 / DURATION), + rewarder=data.ScanningTimeReward(), + time_limit=DURATION, + failure_penalty=-1.0, + terminate_on_time_limit=True, + episode_data_callback=episode_data_callback, + satellite_data_callback=satellite_data_callback, +) + + +policies = {"policy"} +policy_mapping_fn = lambda agent_id, *args, **kwargs: "policy" +module_specs = { + "policy": RLModuleSpec( + model_config_dict={ + "use_lstm": False, + "fcnet_hiddens": [512, 512], + "vf_share_layers": False, + }, + ), +} + +training_args = dict( + lr=3e-5, + gamma=0.99999, + 
train_batch_size=3200, + num_sgd_iter=10, + use_kl_loss=False, + clip_param=0.1, + grad_clip=0.5, +) + +nadir_science = BenchmarkEnv( + env_args=env_args, + policies=policies, + policy_mapping_fn=policy_mapping_fn, + module_specs=module_specs, + training_args=training_args, +) diff --git a/benchmarks/rso_inspection.py b/benchmarks/rso_inspection.py new file mode 100644 index 00000000..99308d95 --- /dev/null +++ b/benchmarks/rso_inspection.py @@ -0,0 +1,327 @@ +from functools import partial + +import numpy as np +from benchmark import BenchmarkEnv +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from Basilisk.utilities.orbitalMotion import elem2rv +from Basilisk.utilities.RigidBodyKinematics import C2MRP + +from bsk_rl import act, data, obs, sats, scene +from bsk_rl.obs.relative_observations import rso_imaged_regions +from bsk_rl.sim import dyn, fsw +from bsk_rl.utils.orbital import ( + fibonacci_sphere, + random_orbit, + random_unit_vector, + relative_to_chief, + rv2HN, +) + + +def sun_hat_chief(self, other): + r_SN_N = ( + self.simulator.world.gravFactory.spiceObject.planetStateOutMsgs[ + self.simulator.world.sun_index + ] + .read() + .PositionVector + ) + r_SN_N = np.array(r_SN_N) + r_SN_N_hat = r_SN_N / np.linalg.norm(r_SN_N) + HN = other.dynamics.HN + return HN @ r_SN_N_hat + + +def SN(sat): + """Returns the Planet-Sun Hill frame""" + planet_state = sat.simulator.world.gravFactory.spiceObject.planetStateOutMsgs[ + sat.simulator.world.sun_index + ].read() + r_NS_N = -np.array(planet_state.PositionVector) + v_NS_N = -np.array(planet_state.VelocityVector) + + return rv2HN(r_N=r_NS_N, v_N=v_NS_N) + + +def eccentricity_vec_N(sat): + h = h_vec_N(sat) + r_N = sat.dynamics.r_BN_N + v_N = sat.dynamics.v_BN_N + mu = sat.dynamics.mu + e_N = np.cross(v_N, h) / mu - r_N / np.linalg.norm(r_N) + return e_N + + +def eccentricity_vec_S(sat): + return SN(sat) @ eccentricity_vec_N(sat) + + +def h_vec_N(sat): + r_N = sat.dynamics.r_BN_N + v_N = sat.dynamics.v_BN_N + 
h_N = np.cross(r_N, v_N) + return h_N + + +def h_vec_S(sat): + return SN(sat) @ h_vec_N(sat) + + +class InspectorSat(sats.Satellite): + observation_spec = [ + obs.SatProperties( + dict(prop="dv_available", norm=10), + dict(prop="eccentricity_vec_S", fn=eccentricity_vec_S, norm=0.04184), + dict(prop="h_vec_S", fn=h_vec_S, norm=5e10), + dict( + prop="r_BN_S", + fn=lambda sat: SN(sat) @ sat.dynamics.r_BN_N, + norm=7e6, + ), + ), + obs.RelativeProperties( + dict(prop="r_DC_Hc", norm=500), + dict(prop="v_DC_Hc", norm=5), + dict( + prop="imaged_hill", + fn=partial( + rso_imaged_regions, + region_centers=fibonacci_sphere(15), + frame="chief_hill", + ), + ), + dict(prop="sun_hat_Hc", fn=sun_hat_chief), + chief_name="RSO", + ), + obs.Eclipse(norm=5700), + obs.Time(), + ] + + action_spec = [ + act.ImpulsiveThrustHill( + chief_name="RSO", + max_dv=1.0, + max_drift_duration=5700.0 * 2, + fsw_action="action_inspect_rso", + ) + ] + dyn_type = ( + dyn.MaxRangeDynModel, + dyn.ConjunctionDynModel, + dyn.RSOInspectorDynModel, + ) + fsw_type = ( + fsw.MagicOrbitalManeuverFSWModel, + fsw.RSOInspectorFSWModel, + ) + + +inspector_sat_args = dict( + imageAttErrorRequirement=1.0, + imageRateErrorRequirement=None, + instrumentBaudRate=1, + dataStorageCapacity=1e6, + batteryStorageCapacity=1e9, + storedCharge_Init=1e9, + conjunction_radius=2.0, + dv_available_init=10.0, + max_range_radius=1000, + enforce_range_on_chief=True, + chief_name="RSO", +) + + +class RSOSat(sats.Satellite): + observation_spec = [ + obs.SatProperties(dict(prop="one", fn=lambda _: 1.0)), + ] + action_spec = [act.NadirPoint(duration=1e9)] + dyn_type = (dyn.ConjunctionDynModel, dyn.RSODynModel) + fsw_type = fsw.FSWModel + + +rso_sat_args = dict(conjunction_radius=2.0) + + +def episode_data_callback(env): + data = {} + + data["inspected_points"] = sum(env.rewarder.data.point_inspect_status.values()) + if len(env.scenario.rso_points) > 0: + data["inspected_fraction"] = env.rewarder.data.num_points_inspected / len( + 
env.scenario.rso_points + ) + if env.rewarder.data.num_points_illuminated > 0: + data["possible_inspected_fraction"] = ( + env.rewarder.data.num_points_inspected + / env.rewarder.data.num_points_illuminated + ) + data["all_inspected"] = env.rewarder.bonus_reward_yielded + if env.rewarder.bonus_reward_yielded: + data["all_inspected_time"] = env.rewarder.bonus_reward_time + + data["episode_duration"] = env.simulator.sim_time + + return data + + +def satellite_data_callback(env, satellite): + data = {} + + if "Inspector" in satellite.name: + fuel_used = ( + satellite.sat_args_generator["dv_available_init"] + - satellite.fsw.dv_available + ) + data["fuel_used"] = fuel_used + data["mean_thrust_size"] = fuel_used / max(satellite.fsw.thrust_count, 1) + data["mean_thrust_duration"] = env.simulator.sim_time / max( + satellite.fsw.thrust_count, 1 + ) + data["thrust_count"] = satellite.fsw.thrust_count + data["alive"] = satellite.is_alive() + data["collision"] = not satellite.dynamics.conjunction_valid() + data["out_of_range"] = not satellite.dynamics.range_valid() + data["fuel_remaining"] = satellite.fsw.fuel_remaining() + chief = [sat for sat in env.satellites if sat.name == "RSO"][0] + data["distance_to_rso"] = np.linalg.norm( + np.array(satellite.dynamics.r_BN_N) - np.array(chief.dynamics.r_BN_N) + ) + + return data + + +def sat_arg_randomizer(satellites): + # Generate the RSO orbit + R_E = 6371.0 # km + a = R_E + np.random.uniform(500, 1100) + e = np.random.uniform(0.0, min(1 - (R_E + 500) / a, (R_E + 1100) / a - 1)) + chief_orbit = random_orbit(a=a, e=e) + + inspectors = [sat for sat in satellites if "Inspector" in sat.name] + rso = [satellite for satellite in satellites if satellite.name == "RSO"][0] + + # Generate the inspector initial states. 
+ args = {} + for inspector in inspectors: + relative_randomizer = relative_to_chief( + chief_name="RSO", + chief_orbit=chief_orbit, + deputy_relative_state={ + inspector.name: lambda: np.concatenate( + ( + random_unit_vector() * np.random.uniform(250, 750), + random_unit_vector() * np.random.uniform(0, 1.0), + ) + ), + }, + ) + args.update(relative_randomizer([rso, inspector])) + + return args + + +def rewarder_config( + fuel_penalty_weight=0.1, + completion_bonus=1.0, + inspection_reward_scale=1.0, + completion_threshold=0.90, +): + return ( + data.RSOInspectionReward( + inspection_reward_scale=inspection_reward_scale, + completion_bonus=completion_bonus, + completion_threshold=completion_threshold, + min_time_for_completion=5700.0, + ), + data.ResourceReward( + resource_fn=lambda sat: sat.fsw.dv_available + if isinstance(sat.fsw, fsw.MagicOrbitalManeuverFSWModel) + else 0.0, + reward_weight=fuel_penalty_weight, + ), + ) + + +env_args = dict( + satellites=[ + RSOSat("RSO", sat_args=rso_sat_args), + InspectorSat("Inspector-1", sat_args=inspector_sat_args), + ], + sat_arg_randomizer=sat_arg_randomizer, + scenario=scene.SphericalRSO( + n_points=100, + radius=1.0, + theta_max=np.radians(30), + range_max=250, + theta_solar_max=np.radians(60), + ), + rewarder=rewarder_config(), + time_limit=5700.0 * 10, + sim_rate=5.0, + episode_data_callback=episode_data_callback, + satellite_data_callback=satellite_data_callback, +) + +policies = {"inspector", "rso"} + + +def policy_mapping_fn(agent_id, *args, **kwargs): + if agent_id == "RSO": + return "rso" + return "inspector" + + +module_specs = { + "inspector": RLModuleSpec( + model_config_dict={ + "use_lstm": False, + "fcnet_hiddens": [512, 512], + "vf_share_layers": False, + }, + ), + "rso": RLModuleSpec( + model_config_dict={ + "use_lstm": False, + "fcnet_hiddens": [2, 2], + "vf_share_layers": False, + }, + ), +} + +training_args = dict( + lr=[[0, 1e-4], [3e6, 1e-4], [6e6, 1e-5], [9e6, 1e-6]], + gamma=0.9999, + 
train_batch_size=3600, + num_sgd_iter=10, + use_kl_loss=False, + clip_param=0.1, + grad_clip=1.0, + entropy_coeff=0.0, +) + +rso_inspection = BenchmarkEnv( + env_args=env_args, + policies=policies, + policy_mapping_fn=policy_mapping_fn, + module_specs=module_specs, + training_args=training_args, +) + + +if __name__ == "__main__": + from bsk_rl import ConstellationTasking + + env = ConstellationTasking( + **env_args, + log_level="INFO", + vizard_dir="/Users/markstephenson/vizard_out/simpletest.bin", + ) + + for i in range(9, 10): + env.reset(seed=i) + for _ in range(3): + actions = {"Inspector-1": np.array([0.0, 0.0, 0.0, 10]), "RSO": 0} + o, r, t1, t2, i = env.step(actions) + if (t1["Inspector-1"] or t2["Inspector-1"]) and (t1["RSO"] or t2["RSO"]): + break diff --git a/benchmarks/train.sh b/benchmarks/train.sh new file mode 100644 index 00000000..83d73875 --- /dev/null +++ b/benchmarks/train.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# Allocation account name +#SBATCH --account=ucb375_asc3 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=32 +#SBATCH --time=23:59:00 +# Jobs to run, inclusive +#SBATCH --array=0-3 +#SBATCH --partition=amilan +#SBATCH --output=/scratch/alpine/%u/job_%a-%j.out +#SBATCH --mail-type=ALL +#SBATCH --qos=normal + +# Copy this script as sweep.sh to modify it for your own use. + +module purge + +echo "Loading modules" +module load python/3.10.2 +module load gcc + +echo "Activating virtual environment" +source /projects/$USER/.venv/bin/activate + +echo "Running training script" + +LOOKUP_TABLE=( + "nadir_science:nadir_science" + "aeos:aeos_single" + "aeos:aeos_constellation" + "rso_inspection:rso_inspection" +) + +OUTPUT_DIR="/scratch/alpine/mast9128/26_05_07_benchmarks" +TEMP_DIR="/scratch/alpine/mast9128/tmp" + +ENV_ID=${LOOKUP_TABLE[$SLURM_ARRAY_TASK_ID]} +IFS=':' read -r ENV ENV_NAME <<< "$ENV_ID" + +for iter in {1..10}; do + python3 /projects/$USER/bsk_rl/benchmarks/benchmark.py -o $OUTPUT_DIR/$ENV_NAME -e $ENV_ID -t $TEMP_DIR + [ $? 
-eq 0 ] && break +done + +echo "== End of Job ==" diff --git a/examples/async_multiagent_training.ipynb b/examples/async_multiagent_training.ipynb index 20ed4813..c5231e23 100644 --- a/examples/async_multiagent_training.ipynb +++ b/examples/async_multiagent_training.ipynb @@ -26,6 +26,7 @@ "outputs": [], "source": [ "from importlib.metadata import version\n", + "\n", "version(\"ray\") # Parent package of RLlib" ] }, @@ -49,25 +50,31 @@ "from bsk_rl import act, data, obs, sats, scene\n", "from bsk_rl.sim import dyn, fsw\n", "\n", - "class ScanningDownlinkDynModel(dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel):\n", + "\n", + "class ScanningDownlinkDynModel(\n", + " dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel\n", + "):\n", " # Define some custom properties to be accessed in the state\n", " @property\n", " def instrument_pointing_error(self) -> float:\n", - " r_BN_P_unit = self.r_BN_P/np.linalg.norm(self.r_BN_P) \n", + " r_BN_P_unit = self.r_BN_P / np.linalg.norm(self.r_BN_P)\n", " c_hat_P = self.satellite.fsw.c_hat_P\n", " return np.arccos(np.dot(-r_BN_P_unit, c_hat_P))\n", - " \n", + "\n", " @property\n", " def solar_pointing_error(self) -> float:\n", - " a = self.world.gravFactory.spiceObject.planetStateOutMsgs[\n", - " self.world.sun_index\n", - " ].read().PositionVector\n", + " a = (\n", + " self.world.gravFactory.spiceObject.planetStateOutMsgs[self.world.sun_index]\n", + " .read()\n", + " .PositionVector\n", + " )\n", " a_hat_N = a / np.linalg.norm(a)\n", " nHat_B = self.satellite.sat_args[\"nHat_B\"]\n", " NB = np.transpose(self.BN)\n", " nHat_N = NB @ nHat_B\n", " return np.arccos(np.dot(nHat_N, a_hat_N))\n", "\n", + "\n", "class ScanningSatellite(sats.AccessSatellite):\n", " observation_spec = [\n", " obs.SatProperties(\n", @@ -75,7 +82,7 @@ " dict(prop=\"battery_charge_fraction\"),\n", " dict(prop=\"wheel_speeds_fraction\"),\n", " dict(prop=\"instrument_pointing_error\", norm=np.pi),\n", - " dict(prop=\"solar_pointing_error\", 
norm=np.pi)\n", + " dict(prop=\"solar_pointing_error\", norm=np.pi),\n", " ),\n", " obs.OpportunityProperties(\n", " dict(prop=\"opportunity_open\", norm=5700),\n", @@ -92,33 +99,37 @@ " act.Desat(duration=45.0),\n", " ]\n", " dyn_type = ScanningDownlinkDynModel\n", - " fsw_type = fsw.ContinuousImagingFSWModel\n", + " fsw_type = (fsw.ContinuousImagingFSWModel, fsw.BasicFSWModel)\n", + "\n", "\n", - "sats = [ScanningSatellite(\n", - " f\"Scanner-{i+1}\",\n", - " sat_args=dict(\n", - " # Data\n", - " dataStorageCapacity=5000 * 8e6, # bits\n", - " storageInit=lambda: np.random.uniform(0.0, 0.8) * 5000 * 8e6,\n", - " instrumentBaudRate=0.5 * 8e6,\n", - " transmitterBaudRate=-50 * 8e6,\n", - " # Power\n", - " batteryStorageCapacity=200 * 3600, # W*s\n", - " storedCharge_Init=lambda: np.random.uniform(0.3, 1.0) * 200 * 3600,\n", - " basePowerDraw=-10.0, # W\n", - " instrumentPowerDraw=-30.0, # W\n", - " transmitterPowerDraw=-25.0, # W\n", - " thrusterPowerDraw=-80.0, # W\n", - " panelArea=0.25,\n", - " # Attitude\n", - " imageAttErrorRequirement=0.1,\n", - " imageRateErrorRequirement=0.1,\n", - " disturbance_vector=lambda: np.random.normal(scale=0.0001, size=3), # N*m\n", - " maxWheelSpeed=6000.0, # RPM\n", - " wheelSpeeds=lambda: np.random.uniform(-3000, 3000, 3),\n", - " desatAttitude=\"nadir\",\n", + "sats = [\n", + " ScanningSatellite(\n", + " f\"Scanner-{i + 1}\",\n", + " sat_args=dict(\n", + " # Data\n", + " dataStorageCapacity=5000 * 8e6, # bits\n", + " storageInit=lambda: np.random.uniform(0.0, 0.8) * 5000 * 8e6,\n", + " instrumentBaudRate=0.5 * 8e6,\n", + " transmitterBaudRate=-50 * 8e6,\n", + " # Power\n", + " batteryStorageCapacity=200 * 3600, # W*s\n", + " storedCharge_Init=lambda: np.random.uniform(0.3, 1.0) * 200 * 3600,\n", + " basePowerDraw=-10.0, # W\n", + " instrumentPowerDraw=-30.0, # W\n", + " transmitterPowerDraw=-25.0, # W\n", + " thrusterPowerDraw=-80.0, # W\n", + " panelArea=0.25,\n", + " # Attitude\n", + " imageAttErrorRequirement=0.1,\n", + " 
imageRateErrorRequirement=0.1,\n", + " disturbance_vector=lambda: np.random.normal(scale=0.0001, size=3), # N*m\n", + " maxWheelSpeed=6000.0, # RPM\n", + " wheelSpeeds=lambda: np.random.uniform(-3000, 3000, 3),\n", + " desatAttitude=\"nadir\",\n", + " ),\n", " )\n", - ") for i in range(4)]" + " for i in range(4)\n", + "]" ] }, { @@ -157,7 +168,7 @@ "duration = 5 * 5700.0 # About 5 orbits\n", "env_args = dict(\n", " satellites=sats,\n", - " scenario=scene.UniformNadirScanning(value_per_second=1/duration),\n", + " scenario=scene.UniformNadirScanning(value_per_second=1 / duration),\n", " rewarder=data.ScanningTimeReward(),\n", " time_limit=duration,\n", " failure_penalty=-1.0,\n", @@ -265,8 +276,7 @@ " \"p0\": RLModuleSpec(),\n", " }\n", " ),\n", - ")\n", - "\n" + ")\n" ] }, { @@ -372,7 +382,9 @@ "source": [ "config.training(\n", " learner_connector=lambda obs_space, act_space: (\n", - " discounting.MakeAddedStepActionValid(expected_train_batch_size=config.train_batch_size),\n", + " discounting.MakeAddedStepActionValid(\n", + " expected_train_batch_size=config.train_batch_size\n", + " ),\n", " discounting.CondenseMultiStepActions(),\n", " ),\n", ")" diff --git a/examples/curriculum_learning.ipynb b/examples/curriculum_learning.ipynb index a439ff10..79dab8b2 100644 --- a/examples/curriculum_learning.ipynb +++ b/examples/curriculum_learning.ipynb @@ -225,7 +225,7 @@ " act.Desat(duration=180.0), # Desaturate for 3 minute\n", " ]\n", " dyn_type = CustomDynamics\n", - " fsw_type = fsw.ContinuousImagingFSWModel" + " fsw_type = (fsw.ContinuousImagingFSWModel, fsw.BasicFSWModel)" ] }, { diff --git a/examples/rllib_training.ipynb b/examples/rllib_training.ipynb index 25bbf36f..e9317177 100644 --- a/examples/rllib_training.ipynb +++ b/examples/rllib_training.ipynb @@ -29,6 +29,7 @@ "outputs": [], "source": [ "from importlib.metadata import version\n", + "\n", "version(\"ray\") # Parent package of RLlib" ] }, @@ -54,25 +55,31 @@ "from bsk_rl import act, data, obs, sats, 
scene\n", "from bsk_rl.sim import dyn, fsw\n", "\n", - "class ScanningDownlinkDynModel(dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel):\n", + "\n", + "class ScanningDownlinkDynModel(\n", + " dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel\n", + "):\n", " # Define some custom properties to be accessed in the state\n", " @property\n", " def instrument_pointing_error(self) -> float:\n", - " r_BN_P_unit = self.r_BN_P/np.linalg.norm(self.r_BN_P) \n", + " r_BN_P_unit = self.r_BN_P / np.linalg.norm(self.r_BN_P)\n", " c_hat_P = self.satellite.fsw.c_hat_P\n", " return np.arccos(np.dot(-r_BN_P_unit, c_hat_P))\n", - " \n", + "\n", " @property\n", " def solar_pointing_error(self) -> float:\n", - " a = self.world.gravFactory.spiceObject.planetStateOutMsgs[\n", - " self.world.sun_index\n", - " ].read().PositionVector\n", + " a = (\n", + " self.world.gravFactory.spiceObject.planetStateOutMsgs[self.world.sun_index]\n", + " .read()\n", + " .PositionVector\n", + " )\n", " a_hat_N = a / np.linalg.norm(a)\n", " nHat_B = self.satellite.sat_args[\"nHat_B\"]\n", " NB = np.transpose(self.BN)\n", " nHat_N = NB @ nHat_B\n", " return np.arccos(np.dot(nHat_N, a_hat_N))\n", "\n", + "\n", "class ScanningSatellite(sats.AccessSatellite):\n", " observation_spec = [\n", " obs.SatProperties(\n", @@ -80,7 +87,7 @@ " dict(prop=\"battery_charge_fraction\"),\n", " dict(prop=\"wheel_speeds_fraction\"),\n", " dict(prop=\"instrument_pointing_error\", norm=np.pi),\n", - " dict(prop=\"solar_pointing_error\", norm=np.pi)\n", + " dict(prop=\"solar_pointing_error\", norm=np.pi),\n", " ),\n", " obs.OpportunityProperties(\n", " dict(prop=\"opportunity_open\", norm=5700),\n", @@ -98,7 +105,7 @@ " act.Desat(duration=60.0),\n", " ]\n", " dyn_type = ScanningDownlinkDynModel\n", - " fsw_type = fsw.ContinuousImagingFSWModel" + " fsw_type = (fsw.ContinuousImagingFSWModel, fsw.BasicFSWModel)" ] }, { @@ -138,7 +145,7 @@ " maxWheelSpeed=6000.0, # RPM\n", " wheelSpeeds=lambda: np.random.uniform(-3000, 
3000, 3),\n", " desatAttitude=\"nadir\",\n", - " )\n", + " ),\n", ")" ] }, @@ -159,7 +166,7 @@ "duration = 5 * 5700.0 # About 5 orbits\n", "env_args = dict(\n", " satellite=sat,\n", - " scenario=scene.UniformNadirScanning(value_per_second=1/duration),\n", + " scenario=scene.UniformNadirScanning(value_per_second=1 / duration),\n", " rewarder=data.ScanningTimeReward(),\n", " time_limit=duration,\n", " failure_penalty=-1.0,\n", @@ -203,7 +210,7 @@ " data[\"reward_per_orbit\"] = reward / orbits\n", " if not env.satellite.is_alive():\n", " data[\"orbits_complete_partial_only\"] = orbits\n", - " \n", + "\n", " return data" ] }, @@ -247,7 +254,7 @@ "config = (\n", " PPOConfig()\n", " .training(**training_args)\n", - " .env_runners(num_env_runners=N_CPUS-1, sample_timeout_s=1000.0)\n", + " .env_runners(num_env_runners=N_CPUS - 1, sample_timeout_s=1000.0)\n", " .environment(\n", " env=\"SatelliteTasking-RLlib\",\n", " env_config=dict(**env_args, episode_data_callback=episode_data_callback),\n", @@ -300,7 +307,7 @@ " config=config.to_dict(),\n", " stop={\"training_iteration\": 10}, # Adjust the number of iterations as needed\n", " checkpoint_freq=10,\n", - " checkpoint_at_end=True\n", + " checkpoint_at_end=True,\n", ")\n", "\n", "# Shutdown Ray\n", diff --git a/examples/satellite_configuration.ipynb b/examples/satellite_configuration.ipynb index 3bd39b0e..3d98c0e8 100644 --- a/examples/satellite_configuration.ipynb +++ b/examples/satellite_configuration.ipynb @@ -28,6 +28,7 @@ "import numpy as np\n", "\n", "from Basilisk.architecture import bskLogging\n", + "\n", "bskLogging.setDefaultLogLevel(bskLogging.BSK_WARNING)\n", "\n", "\n", @@ -154,7 +155,7 @@ " dict(prop=\"wheel_speeds\"),\n", " # You can specify the module to use for the observation, but it is not necessary\n", " # if only one module has for the property\n", - " dict(prop=\"battery_charge_fraction\", module=\"dynamics\"), \n", + " dict(prop=\"battery_charge_fraction\", module=\"dynamics\"),\n", " # Properties can 
be normalized by some constant. This is generally desirable\n", " # for RL algorithms to keep values around [-1, 1].\n", " dict(prop=\"r_BN_P\", norm=7e6),\n", @@ -164,6 +165,7 @@ " dyn_type = dyn.BasicDynamicsModel\n", " fsw_type = fsw.BasicFSWModel\n", "\n", + "\n", "env = SatelliteTasking(\n", " satellite=SatPropsSatellite(\"PropSat-1\", {}, obs_type=dict),\n", " log_level=\"CRITICAL\",\n", @@ -190,15 +192,15 @@ " @property\n", " def meaning_of_life(self):\n", " return 42\n", - " \n", + "\n", + "\n", "class BespokeSatPropsSatellite(sats.Satellite):\n", - " observation_spec = [\n", - " obs.SatProperties(dict(prop=\"meaning_of_life\"))\n", - " ]\n", + " observation_spec = [obs.SatProperties(dict(prop=\"meaning_of_life\"))]\n", " action_spec = [act.Drift()]\n", " dyn_type = dyn.BasicDynamicsModel\n", " fsw_type = BespokeFSWModel\n", "\n", + "\n", "env = SatelliteTasking(\n", " satellite=BespokeSatPropsSatellite(\"BespokeSat-1\", {}, obs_type=dict),\n", " log_level=\"CRITICAL\",\n", @@ -228,6 +230,7 @@ " dyn_type = dyn.BasicDynamicsModel\n", " fsw_type = fsw.BasicFSWModel\n", "\n", + "\n", "env = SatelliteTasking(\n", " satellite=CustomSatPropsSatellite(\"BespokeSat-1\", {}, obs_type=dict),\n", " log_level=\"CRITICAL\",\n", @@ -257,7 +260,7 @@ " observation_spec = [\n", " obs.OpportunityProperties(\n", " # Properties can be added by some default names\n", - " dict(prop=\"priority\"), \n", + " dict(prop=\"priority\"),\n", " # They can also be normalized\n", " dict(prop=\"opportunity_open\", norm=5700.0),\n", " # Or they can be specified by an arbitrary function\n", @@ -267,7 +270,8 @@ " ]\n", " action_spec = [act.Drift()]\n", " dyn_type = dyn.ImagingDynModel\n", - " fsw_type = fsw.ImagingFSWModel\n", + " fsw_type = (fsw.ImagingFSWModel, fsw.BasicFSWModel)\n", + "\n", "\n", "env = SatelliteTasking(\n", " satellite=OppPropsSatellite(\"OppSat-1\", {}, obs_type=dict),\n", @@ -300,12 +304,13 @@ "class ComposedObsSatellite(sats.Satellite):\n", " observation_spec = [\n", " 
 obs.Eclipse(),\n", - " obs.SatProperties(dict(prop=\"battery_charge_fraction\"))\n", + " obs.SatProperties(dict(prop=\"battery_charge_fraction\")),\n", " ]\n", " action_spec = [act.Drift()]\n", " dyn_type = dyn.BasicDynamicsModel\n", " fsw_type = fsw.BasicFSWModel\n", "\n", + "\n", "env = SatelliteTasking(\n", " satellite=ComposedObsSatellite(\"PropSat-1\", {}, obs_type=dict),\n", " log_level=\"CRITICAL\",\n", @@ -404,11 +409,14 @@ " act.Charge(duration=120.0),\n", " act.Desat(duration=60.0),\n", " # One action can be included multiple time, if different settings are desired\n", - " act.Charge(duration=600.0,),\n", + " act.Charge(\n", + " duration=600.0,\n", + " ),\n", " ]\n", " dyn_type = dyn.BasicDynamicsModel\n", " fsw_type = fsw.BasicFSWModel\n", "\n", + "\n", "env = SatelliteTasking(\n", " satellite=ActionSatellite(\"ActSat-1\", {}, obs_type=dict),\n", " log_level=\"INFO\",\n", @@ -416,9 +424,9 @@ "env.reset()\n", "\n", "# Try each action; index corresponds to the order of addition\n", - "_ =env.step(0)\n", - "_ =env.step(1)\n", - "_ =env.step(2)" + "_ = env.step(0)\n", + "_ = env.step(1)\n", + "_ = env.step(2)" ] }, { @@ -470,7 +478,7 @@ " act.Image(n_ahead_image=3)\n", " ]\n", " dyn_type = dyn.ImagingDynModel\n", - " fsw_type = fsw.ImagingFSWModel\n", + " fsw_type = (fsw.ImagingFSWModel, fsw.BasicFSWModel)\n", "\n", "env = SatelliteTasking(\n", " satellite=ImageActSatellite(\"ActSat-2\", {}),\n", diff --git a/examples/simple_environment.ipynb b/examples/simple_environment.ipynb index c192ff9a..35be3df6 100644 --- a/examples/simple_environment.ipynb +++ b/examples/simple_environment.ipynb @@ -26,6 +26,7 @@ "from bsk_rl.sim import dyn, fsw\n", "\n", "from Basilisk.architecture import bskLogging\n", + "\n", "bskLogging.setDefaultLogLevel(bskLogging.BSK_WARNING)\n" ] }, @@ -51,8 +52,7 @@ "class MyScanningSatellite(sats.AccessSatellite):\n", " observation_spec = [\n", " obs.SatProperties(\n", - " dict(prop=\"storage_level_fraction\"),\n", - " 
dict(prop=\"battery_charge_fraction\")\n", + " dict(prop=\"storage_level_fraction\"), dict(prop=\"battery_charge_fraction\")\n", " ),\n", " obs.Eclipse(),\n", " ]\n", @@ -61,7 +61,7 @@ " act.Charge(duration=600.0), # Charge for 10 minutes\n", " ]\n", " dyn_type = dyn.ContinuousImagingDynModel\n", - " fsw_type = fsw.ContinuousImagingFSWModel" + " fsw_type = (fsw.ContinuousImagingFSWModel, fsw.BasicFSWModel)" ] }, { @@ -198,7 +198,9 @@ "source": [ "while not truncated:\n", " observation, reward, terminated, truncated, info = env.step(action=1)\n", - " print(f\"Charge level: {observation[1]:.3f} ({env.unwrapped.simulator.sim_time:.1f} seconds)\\n\\tEclipse: start: {observation[2]:.1f} end: {observation[3]:.1f}\")" + " print(\n", + " f\"Charge level: {observation[1]:.3f} ({env.unwrapped.simulator.sim_time:.1f} seconds)\\n\\tEclipse: start: {observation[2]:.1f} end: {observation[3]:.1f}\"\n", + " )" ] }, { diff --git a/examples/time_discounted_gae.ipynb b/examples/time_discounted_gae.ipynb index 8b627f96..7360f744 100644 --- a/examples/time_discounted_gae.ipynb +++ b/examples/time_discounted_gae.ipynb @@ -105,7 +105,7 @@ " act.Desat(duration=60.0),\n", " ]\n", " dyn_type = ScanningDownlinkDynModel\n", - " fsw_type = fsw.ContinuousImagingFSWModel\n", + " fsw_type = (fsw.ContinuousImagingFSWModel, fsw.BasicFSWModel)\n", "\n", "\n", "sat = ScanningSatellite(\n", diff --git a/src/bsk_rl/data/rso_inspection.py b/src/bsk_rl/data/rso_inspection.py index ed4e6cc7..639b2b19 100644 --- a/src/bsk_rl/data/rso_inspection.py +++ b/src/bsk_rl/data/rso_inspection.py @@ -158,7 +158,16 @@ def compare_log_states(self, _, logs) -> RSOInspectionData: for rso_point, log in zip( self.data.point_illuminate_status.keys(), illuminated_logs ): - if any(log): + check_log = log + start_time = ( + self.satellite.simulator.sim_time + - self.satellite.simulator.sim_rate * len(log) + ) + # Delete first two illumination points to allow forced attitude control to catch up + if start_time == 0: + 
check_log = log[2:] + + if any(check_log): point_illuminate_status[rso_point] = True if len(point_inspect_status) > 0: diff --git a/src/bsk_rl/obs/relative_observations.py b/src/bsk_rl/obs/relative_observations.py index 8cc09aff..cd5af610 100644 --- a/src/bsk_rl/obs/relative_observations.py +++ b/src/bsk_rl/obs/relative_observations.py @@ -111,6 +111,9 @@ def rso_imaged_regions( assert frame in ["chief_hill", "chief_body"] point_inspect_status = servicer.data_store.data.point_inspect_status + if all(not inspected for inspected in point_inspect_status.values()): + return np.zeros(len(region_centers)) + region_centers_C = [] for region_center in region_centers: if frame == "chief_hill": diff --git a/src/bsk_rl/sim/fsw/base.py b/src/bsk_rl/sim/fsw/base.py index 8783e0f5..b629b492 100644 --- a/src/bsk_rl/sim/fsw/base.py +++ b/src/bsk_rl/sim/fsw/base.py @@ -254,7 +254,7 @@ def _make_task_list(self): ] class MRPControlTask(Task): - """Task to control the satellite attitude magically (i.e. without actuators).""" + """Task to control the satellite attitude magically to the reference (i.e. without actuators).""" name = "mrpControlTask" @@ -345,11 +345,14 @@ def _setup_fsw_objects(self, **kwargs) -> None: self._add_model_to_task(self.trackingError, priority=1197) + def reset_for_action(self) -> None: + """Tracking error is enabled by default for all tasks.""" + self.fsw.simulator.enableTask(self.name + self.fsw.satellite.name) + @action def action_attitude_mrp(self, sigma_RN: np.ndarray) -> None: """Point the satellite to the specified attitude.""" self.attRefMsg.write(messaging.AttRefMsgPayload(sigma_RN=sigma_RN)) - self.simulator.enableTask(self.TrackingErrorTask.name + self.satellite.name) class BasicFSWModel(FSWModel): @@ -554,7 +557,6 @@ def action_desat(self) -> None: power sink, and enables the desaturation tasks. This action typically needs to be called multiple times to fully desaturate the wheels. 
""" - self.trackingError.Reset(self.simulator.sim_time_ns) self.thrDesatControl.Reset(self.simulator.sim_time_ns) self.thrDump.Reset(self.simulator.sim_time_ns) self.dynamics.thrusterPowerSink.powerStatus = 1 @@ -571,7 +573,6 @@ def action_desat(self) -> None: pass else: raise ValueError(f"{self.desatAttitude} not a valid desatAttitude") - self.simulator.enableTask(self.TrackingErrorTask.name + self.satellite.name) class MRPControlTask(Task): """Task to control the satellite attitude using reaction wheels.""" diff --git a/src/bsk_rl/sim/fsw/ground_imaging.py b/src/bsk_rl/sim/fsw/ground_imaging.py index 5af1f68c..2adeb54a 100644 --- a/src/bsk_rl/sim/fsw/ground_imaging.py +++ b/src/bsk_rl/sim/fsw/ground_imaging.py @@ -11,7 +11,7 @@ ) from bsk_rl.sim import dyn -from bsk_rl.sim.fsw import BasicFSWModel, SteeringFSWModel, Task, action +from bsk_rl.sim.fsw import FSWModel, BasicFSWModel, SteeringFSWModel, Task, action from bsk_rl.utils import vizard from bsk_rl.utils.functional import default_args from bsk_rl.utils.orbital import rv2HN @@ -20,7 +20,7 @@ from bsk_rl.sim.dyn import DynamicsModelABC -class ImagingFSWModel(BasicFSWModel): +class ImagingFSWModel(FSWModel): """Extend FSW with instrument pointing and triggering control.""" @classmethod @@ -102,6 +102,9 @@ def setup_location_pointing( messaging.AttGuidMsg_C_addAuthor( self.locPoint.attGuidOutMsg, self.fsw.attGuidMsg ) + messaging.AttRefMsg_C_addAuthor( + self.locPoint.attRefOutMsg, self.fsw.attRefMsg + ) self._add_model_to_task(self.locPoint, priority=1198) @@ -188,15 +191,11 @@ def action_downlink(self) -> None: baud rate. The transmitter power sink will be active as long as the task is enabled. 
""" self.hillPoint.Reset(self.simulator.sim_time_ns) - self.trackingError.Reset(self.simulator.sim_time_ns) self.dynamics.transmitter.dataStatus = 1 self.dynamics.transmitterPowerSink.powerStatus = 1 self.simulator.enableTask( BasicFSWModel.NadirPointTask.name + self.satellite.name ) - self.simulator.enableTask( - BasicFSWModel.TrackingErrorTask.name + self.satellite.name - ) class ContinuousImagingFSWModel(ImagingFSWModel): diff --git a/src/bsk_rl/sim/fsw/rso_inspection.py b/src/bsk_rl/sim/fsw/rso_inspection.py index a7e61e15..edb82e13 100644 --- a/src/bsk_rl/sim/fsw/rso_inspection.py +++ b/src/bsk_rl/sim/fsw/rso_inspection.py @@ -49,6 +49,9 @@ def setup_location_pointing( messaging.AttGuidMsg_C_addAuthor( self.locPoint.attGuidOutMsg, self.fsw.attGuidMsg ) + messaging.AttRefMsg_C_addAuthor( + self.locPoint.attRefOutMsg, self.fsw.attRefMsg + ) self._add_model_to_task(self.locPoint, priority=1198) diff --git a/src/bsk_rl/sim/world.py b/src/bsk_rl/sim/world.py index 2b63e2c9..2b0bcaea 100644 --- a/src/bsk_rl/sim/world.py +++ b/src/bsk_rl/sim/world.py @@ -49,7 +49,7 @@ _DATA_FETCHER_API = True except ImportError: - bskPath = __path__[0] + bsk_path = __path__[0] _DATA_FETCHER_API = False logger = logging.getLogger(__name__) @@ -214,14 +214,6 @@ def setup_ephem_object(self, priority: int = 988, **kwargs) -> None: self.world_task_name, self.ephemConverter, ModelPriority=priority ) - def __del__(self) -> None: - """Log when world is deleted and unload SPICE.""" - super().__del__() - try: - self.gravFactory.unloadSpiceKernels() - except AttributeError: - pass - class EclipseWorldModel(WorldModel): def __init__(self, *args, **kwargs) -> None: diff --git a/src/bsk_rl/utils/orbital.py b/src/bsk_rl/utils/orbital.py index fb2a3ed6..246762d9 100644 --- a/src/bsk_rl/utils/orbital.py +++ b/src/bsk_rl/utils/orbital.py @@ -505,7 +505,7 @@ def _generate_eclipses(self, t: float) -> None: self._eclipse_search_time = t - def next_eclipse(self, t: float, max_tries: int = 100) -> 
tuple[float, float]: + def next_eclipse(self, t: float, max_tries: int = 2) -> tuple[float, float]: """Find the soonest eclipse transitions. The returned values are not necessarily from the same eclipse event, such as @@ -519,7 +519,7 @@ def next_eclipse(self, t: float, max_tries: int = 100) -> tuple[float, float]: eclipse_start: Nearest upcoming eclipse beginning eclipse_end: Nearest upcoming eclipse end """ - for i in range(max_tries): + for i in range(max_tries + 1): if any([t_start > t for t_start in self._eclipse_starts]) and any( [t_end > t for t_end in self._eclipse_ends] ): @@ -529,8 +529,11 @@ def next_eclipse(self, t: float, max_tries: int = 100) -> tuple[float, float]: eclipse_end = min([t_end for t_end in self._eclipse_ends if t_end > t]) return eclipse_start, eclipse_end - self._generate_eclipses(t + i * self.dt * 10) + self._generate_eclipses(t + i * 6000) + logger.warning( + f"Could not find eclipse transitions in next {self._eclipse_search_time - t:.1f} seconds" + ) return 1.0, 1.0 @property @@ -562,13 +565,6 @@ def r_BP_P(self) -> interp1d: fill_value="extrapolate", ) - def __del__(self) -> None: - """Unload spice kernels when object is deleted.""" - try: - self.gravFactory.unloadSpiceKernels() - except AttributeError: - pass - def lla2ecef(lat: float, long: float, radius: float): """Project LLA to Earth Centered, Earth Fixed location. 
diff --git a/tests/unittest/sim/test_world.py b/tests/unittest/sim/test_world.py index 53d8f633..8de4e2dc 100644 --- a/tests/unittest/sim/test_world.py +++ b/tests/unittest/sim/test_world.py @@ -74,17 +74,13 @@ def test_omega_PN_N(self): @patch(baseworld + "setup_gravity_bodies") @patch(baseworld + "setup_ephem_object") def test_setup_and_delete(self, grav_set, epoch_set): - world = WorldModel(MagicMock(), 1.0) + world = WorldModel(MagicMock(), 1.0) # noqa: F841 for setter in (grav_set, epoch_set): setter.assert_called_once() - unload_function = MagicMock() - world.gravFactory = MagicMock(unloadSpiceKernels=unload_function) - del world - unload_function.assert_called_once() @patch(baseworld + "_setup_world_objects", MagicMock()) @patch(module + "simIncludeGravBody", MagicMock()) - def testsetup_gravity_bodies(self): + def test_setup_gravity_bodies(self): # Smoke test world = WorldModel(MagicMock(), 1.0) world.simulator = MagicMock() @@ -93,7 +89,7 @@ def testsetup_gravity_bodies(self): @patch(baseworld + "_setup_world_objects", MagicMock()) @patch(module + "ephemerisConverter", MagicMock()) - def testsetup_epoch_object(self): + def test_setup_epoch_object(self): # Smoke test world = WorldModel(MagicMock(), 1.0) world.simulator = MagicMock() @@ -109,7 +105,7 @@ class TestAtmosphereWorldModel: @patch(baseworld + "_setup_world_objects", MagicMock()) @patch(module + "exponentialAtmosphere", MagicMock()) - def testsetup_atmosphere_density_model(self): + def test_setup_atmosphere_density_model(self): # Smoke test world = AtmosphereWorldModel(MagicMock(), 1.0) world.simulator = MagicMock() @@ -129,7 +125,7 @@ class TestEclipseWorldModel: @patch(baseworld + "_setup_world_objects", MagicMock()) @patch(module + "eclipse", MagicMock()) - def testsetup_eclipse_object(self): + def test_setup_eclipse_object(self): # Smoke test world = EclipseWorldModel(MagicMock(), 1.0) world.simulator = MagicMock() @@ -151,7 +147,7 @@ def test_setup_world_objects(self, ground_set): 
@patch(groundworld + "_setup_world_objects", MagicMock()) @patch(groundworld + "_create_ground_station") - def testsetup_ground_locations(self, mock_gs_create): + def test_setup_ground_locations(self, mock_gs_create): world = GroundStationWorldModel(MagicMock(), 1.0) world.setup_ground_locations([dict(a=1), dict(b=2)], 1000.0, 1.0, 1000.0) mock_gs_create.assert_has_calls(