Conversation
f4e79b1 to dfae80a — Compare
Pull request overview
Adds a lightweight benchmark harness (env templates + a Ray RLlib PPO training entrypoint) to support repeatable training runs against specific benchmark environments, and fixes a Basilisk support-data path fallback bug.
Changes:
- Fix Basilisk `dataFetcher` `ImportError` fallback path variable naming in `WorldModel` setup.
- Add `benchmarks/train.py`, an RLlib PPO training script with checkpointing and dynamic benchmark loading.
- Add a first benchmark environment (`nadir_science`) plus a small `Benchmark` dataclass template.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `src/bsk_rl/sim/world.py` | Fixes the fallback Basilisk path variable used when the dataFetcher API is unavailable. |
| `benchmarks/train.py` | New RLlib PPO training/continue script for benchmark runs (checkpoint management, dynamic env import). |
| `benchmarks/nadir_science.py` | Defines a benchmark environment configuration (satellite model + env/training args). |
| `benchmarks/env_template.py` | Introduces a simple `Benchmark` dataclass container for env/training configuration. |
Comments suppressed due to low confidence (2)
benchmarks/train.py:359

- The script initializes Ray and a PPO algorithm but never calls `ppo.stop()` / `ray.shutdown()` before exiting. This can leave worker processes running (especially in interactive or repeated benchmark runs) and can hold onto temp dirs/object store resources. Consider adding a `try`/`finally` around training to ensure cleanup.
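A minimal sketch of the suggested `try`/`finally` pattern. `FakeAlgo` is a stand-in invented here for illustration; in the real script the cleanup would call `ppo.stop()` and `ray.shutdown()`:

```python
class FakeAlgo:
    """Stand-in for an RLlib PPO algorithm object; illustrative only."""

    def __init__(self):
        self.stopped = False

    def train(self):
        raise RuntimeError("training crashed")

    def stop(self):
        self.stopped = True


def run_with_cleanup(algo, iterations=1):
    """Guarantee cleanup runs even when training raises.

    In benchmarks/train.py the finally block would call ppo.stop()
    and ray.shutdown() so worker processes and temp dirs are released.
    """
    try:
        for _ in range(iterations):
            algo.train()
    finally:
        algo.stop()
```

Even though `train()` raises here, the `finally` block still marks the algorithm as stopped before the exception propagates.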
benchmarks/nadir_science.py:133

- These `print(...)` statements will execute at import time, which is noisy when using the dynamic import in `benchmarks/train.py` and makes it harder to use these benchmarks as a library. Consider removing them or switching to logging guarded by `if __name__ == "__main__"` / a verbosity flag.
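One way the suggestion could look (a sketch; `env_args` is a hypothetical placeholder, not the benchmark's real configuration):

```python
import logging

logger = logging.getLogger(__name__)

env_args = {"duration": 5700.0}  # hypothetical placeholder configuration


def describe_benchmark():
    """Log the configuration instead of printing at import time."""
    logger.debug("env_args: %s", env_args)


if __name__ == "__main__":
    # Output only appears when the file is run directly,
    # not when train.py imports it dynamically.
    logging.basicConfig(level=logging.DEBUG)
    describe_benchmark()
```

Importing the module stays silent; running it directly still shows the configuration dump.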
```python
training_args = dict(
    lr=3e-5,
    gamma=0.999,
    train_batch_size=1000,
    num_sgd_iter=10,
    use_kl_loss=False,
    clip_param=0.1,
    grad_clip=0.5,
)

nadir_science_benchmark = Benchmark(
    env_args=env_args,
    policies=policies,
    policy_mapping_fn=policy_mapping_fn,
    module_specs=module_specs,
    training_args=training_args,
)
```
```python
ray.init(
    ignore_reinit_error=True,
    num_cpus=get_available_cores(),
    object_store_memory=2_000_000_000,  # 2 GB
    _temp_dir=temp_dir,
)
config = (
    PPOConfig()
    .training(**training_args)
    .env_runners(
        num_env_runners=num_env_runners,
        sample_timeout_s=50000.0,
```
The `print` statements dumping `module_specs` (and the hard-coded `RLModuleSpec` dict) look like debugging leftovers and will spam logs on every run, especially on clusters. Consider removing them or switching to `logger.debug(...)` behind a verbosity flag.
```python
def train(
```
The local variable name `iter` shadows Python's built-in `iter()` function, which can be confusing and makes debugging harder (and can break code if `iter` is later needed in this scope). Rename it to something like `iteration`/`checkpoint_iter` in `load_existing_model`/`train`.
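A tiny self-contained demonstration of why the shadowing matters (the function name is invented for illustration):

```python
def bad_checkpoint_scan():
    """Shadowing `iter` works fine until the built-in is actually needed."""
    iter = 3  # local int shadows the built-in iter()
    try:
        iter(range(iter))  # attempts to call the built-in -> TypeError
    except TypeError as exc:
        return str(exc)
```

Calling the shadowed name raises `'int' object is not callable`, which can be a puzzling failure far from the assignment that caused it.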
```python
"--env",
type=str,
default="nadir_science:nadir_science_benchmark",
```
`avs_rl_tools` is imported here but it is not listed in `pyproject.toml` dependencies or optional extras, so running `benchmarks/train.py` will fail in a clean install. Either vendor/inline the small `sanitize_np` functionality, move it into this repo, or add the package to an appropriate optional dependency group.
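If vendoring is chosen, a small inline replacement might look like the following. This is a guess at what `sanitize_np` does (recursively converting NumPy types to plain Python ones); the actual `avs_rl_tools` implementation may differ:

```python
import numpy as np


def sanitize_np(obj):
    """Recursively convert NumPy scalars/arrays to plain Python types.

    Sketch of a possible inline replacement for avs_rl_tools.sanitize_np;
    useful when serializing configs/results that may contain NumPy values.
    """
    if isinstance(obj, np.generic):      # np.float32, np.int64, ...
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {k: sanitize_np(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(sanitize_np(v) for v in obj)
    return obj
```

Vendoring ~15 lines avoids adding an external dependency for one helper, at the cost of divergence if the upstream utility changes.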
```diff
 except ImportError:
-    bskPath = __path__[0]
+    bsk_path = __path__[0]
     _DATA_FETCHER_API = False
```
The `_DATA_FETCHER_API = False` fallback path (where Basilisk's dataFetcher isn't available) is currently untested, and this change fixes a name mismatch that would only surface in that branch. Consider adding a unit test that forces `_DATA_FETCHER_API` to `False` (e.g., via `monkeypatch`) and asserts `setup_gravity_bodies` uses the fallback paths without raising.
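A runnable sketch of the suggested test. A stub module stands in for `bsk_rl.sim.world` so the pattern is self-contained here; the real test would import the actual module and use pytest's `monkeypatch` fixture:

```python
import types

# Stub standing in for bsk_rl.sim.world; attribute names follow the review
# comment, the setup logic below is a simplified mirror, not the real code.
world = types.SimpleNamespace(_DATA_FETCHER_API=True, bsk_path="/stub/basilisk")


def setup_gravity_bodies():
    """Simplified mirror of the branch under review."""
    if world._DATA_FETCHER_API:
        return "dataFetcher"
    return world.bsk_path  # fallback exercised only when the API is absent


def test_setup_uses_fallback_path(monkeypatch):
    # pytest's monkeypatch restores the attribute after the test
    monkeypatch.setattr(world, "_DATA_FETCHER_API", False)
    assert setup_gravity_bodies() == "/stub/basilisk"
```

With pytest, `monkeypatch.setattr` guarantees the flag is restored afterwards, so the fallback branch can be exercised without affecting other tests.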
```python
    training_args={},
    temp_dir="/tmp",
):
    """Configure a PPO model for training with sMDP discounting and asynchronous multiagent actions."""
```
`training_args={}` as a default argument is a mutable default and will be shared across calls to `create_new_model`, which can lead to surprising cross-run configuration leakage. Use `training_args=None` and initialize to `{}` inside the function (or use an immutable mapping type).
```diff
-    training_args={},
+    training_args=None,
     temp_dir="/tmp",
 ):
     """Configure a PPO model for training with sMDP discounting and asynchronous multiagent actions."""
+    if training_args is None:
+        training_args = {}
```
```python
# TODO remove, for cluster only
torch.set_num_threads(11)
os.environ["MKL_NUM_THREADS"] = "11"
```
This script forces `torch.set_num_threads(11)` and `MKL_NUM_THREADS=11` unconditionally, which can severely underutilize or oversubscribe CPUs depending on the machine/SLURM allocation and makes runs non-reproducible across environments. Consider deriving this from `get_available_cores()` (and/or a CLI flag) and only setting it when explicitly requested.
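One possible shape for the suggestion. The script already has a `get_available_cores()`; the implementation below is an assumption about what such a helper might look like, and the `torch` call is left commented since torch may not be importable everywhere:

```python
import os


def get_available_cores():
    """Prefer the SLURM allocation, then CPU affinity, then cpu_count().

    Sketch only; the real get_available_cores() in train.py may differ.
    """
    if "SLURM_CPUS_PER_TASK" in os.environ:
        return int(os.environ["SLURM_CPUS_PER_TASK"])
    try:
        return len(os.sched_getaffinity(0))  # respects cgroups/taskset limits
    except AttributeError:  # sched_getaffinity unavailable on macOS/Windows
        return os.cpu_count() or 1


def configure_threads(requested=None):
    """Derive thread counts from the allocation instead of hard-coding 11."""
    n = requested if requested is not None else get_available_cores()
    os.environ["MKL_NUM_THREADS"] = str(n)
    # torch.set_num_threads(n)  # enable where torch is a guaranteed dependency
    return n
```

An explicit `--num-threads` CLI flag could feed `requested`, keeping cluster runs reproducible while leaving laptops at their natural core count.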
fe9180d to 0b7b2b1 — Compare

5828acc to ed7c2ff — Compare
Description
Closes #XXX
Type of change
How should this pull request be reviewed?
How Has This Been Tested?
Please describe how tests have been updated to verify your changes.
Future Work
What future tasks are needed, if any?
Checklist