From 663ef0c688c5d6743b3f52287a7e3edfb982f153 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Mon, 3 Feb 2025 22:34:42 -0800 Subject: [PATCH 01/17] work --- .vscode/settings.json | 6 +- docs/api_usage.md | 2 +- examples/alphazero/network.py | 177 ++++++++++++++++--------- examples/alphazero/stateful_network.py | 136 +++++++++++++++++++ examples/alphazero/train.py | 26 ++-- examples/minatar-ppo/requirements.txt | 4 +- examples/minatar-ppo/train.py | 156 ++++++++++++---------- requirements/requirements-dev.txt | 3 +- 8 files changed, 351 insertions(+), 159 deletions(-) create mode 100644 examples/alphazero/stateful_network.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 14ed2a311..032139fa8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -24,7 +24,9 @@ "--config", "pyproject.toml", "--ignore", "E203,E501,W503" ], - "python.linting.flake8Enabled": true, + "python.linting.flake8Enabled": false, "python.defaultInterpreterPath": "${workspace}/venv/bin/python3", - "python.terminal.activateEnvInCurrentTerminal": true + "python.terminal.activateEnvInCurrentTerminal": true, + "python.linting.enabled": false, + "python.linting.pylintEnabled": true } diff --git a/docs/api_usage.md b/docs/api_usage.md index f8e647d11..c3a877bec 100644 --- a/docs/api_usage.md +++ b/docs/api_usage.md @@ -68,7 +68,7 @@ init_fn = jax.jit(jax.vmap(env.init)) step_fn = jax.jit(jax.vmap(env.step)) # Prepare baseline model -# Note that it additionaly requires Haiku library ($ pip install dm-haiku) +# Note that it additionaly requires equinox library ($ pip install equinox) model_id = "go_9x9_v0" model = pgx.make_baseline_model(model_id) diff --git a/examples/alphazero/network.py b/examples/alphazero/network.py index e6ca93004..4466f74e8 100644 --- a/examples/alphazero/network.py +++ b/examples/alphazero/network.py @@ -1,93 +1,138 @@ # We referred to Haiku's ResNet implementation: # https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/nets/resnet.py -import haiku as hk +import equinox as eqx import jax import jax.numpy as jnp -class BlockV1(hk.Module): - def __init__(self, num_channels, name="BlockV1"): - super(BlockV1, self).__init__(name=name) - self.num_channels = num_channels +class BlockV1(eqx.Module): + conv1: eqx.nn.Conv2d + conv2: eqx.nn.Conv2d + norm1: eqx.nn.BatchNorm + norm2: eqx.nn.BatchNorm - def __call__(self, x, is_training, test_local_stats): + def __init__(self, in_channels, out_channels, key): + keys = jax.random.split(key, 2) + self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3, key=keys[0]) + self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3, key=keys[1]) + self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) + self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) + + def __call__(self, x, state): i = x - x = hk.Conv2D(self.num_channels, kernel_shape=3)(x) - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) + x = self.conv1(x) + x, state = self.norm1(x, state) x = jax.nn.relu(x) - x = hk.Conv2D(self.num_channels, kernel_shape=3)(x) - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) - return jax.nn.relu(x + i) + x = self.conv2(x) + x, state = self.norm2(x, state) + return jax.nn.relu(x + i), state + +class BlockV2(eqx.Module): + conv1: eqx.nn.Conv2d + conv2: eqx.nn.Conv2d + norm1: eqx.nn.BatchNorm + norm2: eqx.nn.BatchNorm -class BlockV2(hk.Module): - def __init__(self, num_channels, name="BlockV2"): - super(BlockV2, self).__init__(name=name) - self.num_channels = num_channels + def __init__(self, in_channels, out_channels, key): + keys = jax.random.split(key, 2) + self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3, key=keys[0]) + self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3, key=keys[1]) + self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) + self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) - def __call__(self, x, is_training, test_local_stats): + def __call__(self, x, state): i = x - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) - x = jax.nn.relu(x) - x = hk.Conv2D(self.num_channels, kernel_shape=3)(x) - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) + x, state = self.norm1(x, state) x = jax.nn.relu(x) - x = hk.Conv2D(self.num_channels, kernel_shape=3)(x) - return x + i + x = self.conv1(x) + x, state = self.norm2(x, state) + x = jax.nn.relu(x + i) + x = self.conv2(x) + return x + i, state -class AZNet(hk.Module): +class AZNet(eqx.Module): """AlphaZero NN architecture.""" + init_layers: list + resnet: list + post_resnet: list + policy_head: list + value_head: list + def __init__( self, num_actions, - num_channels: int = 64, + input_channels, + key, + output_channels: int = 64, num_blocks: int = 5, resnet_v2: bool = True, - name="az_net", ): - super().__init__(name=name) self.num_actions = num_actions - self.num_channels = num_channels - self.num_blocks = num_blocks - self.resnet_v2 = resnet_v2 - self.resnet_cls = BlockV2 if resnet_v2 else BlockV1 - - def __call__(self, x, is_training, test_local_stats): + resnet_cls = BlockV2 if resnet_v2 else BlockV1 + + keys = jax.random.split(key, num_blocks + 5) + self.init_layers = [eqx.nn.Conv2d(input_channels, output_channels, kernel_size=3, key=keys[0])] + if not resnet_v2: + self.init_layers += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] + self.resnet = [resnet_cls(output_channels, output_channels, keys[i + 1]) for i in range(num_blocks)] + self.post_resnet = [] + if resnet_v2: + self.post_resnet += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] + self.policy_head = [ + eqx.nn.Conv2d(output_channels, 2, kernel_size=1, key=keys[num_blocks + 1]), + eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), + jax.nn.relu, + lambda x: x.flatten(), + eqx.nn.Linear(200, self.num_actions, key=keys[num_blocks + 2]), + ] + + self.value_head = [ + eqx.nn.Conv2d(output_channels, 1, kernel_size=1, key=keys[num_blocks + 3]), + eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), + jax.nn.relu, + lambda x: x.flatten(), + eqx.nn.Linear(200, output_channels, key=keys[num_blocks + 2]), + jax.nn.relu, + eqx.nn.Linear(output_channels, 1, key=keys[num_blocks + 2]), + jnp.tanh, + jnp.squeeze, + ] + + def __call__(self, x, state): x = x.astype(jnp.float32) - x = hk.Conv2D(self.num_channels, kernel_shape=3)(x) - - if not self.resnet_v2: - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) - x = jax.nn.relu(x) - - for i in range(self.num_blocks): - x = self.resnet_cls(self.num_channels, name=f"block_{i}")( - x, is_training, test_local_stats - ) - - if self.resnet_v2: - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) - x = jax.nn.relu(x) - - # policy head - logits = hk.Conv2D(output_channels=2, kernel_shape=1)(x) - logits = hk.BatchNorm(True, True, 0.9)(logits, is_training, test_local_stats) - logits = jax.nn.relu(logits) - logits = hk.Flatten()(logits) - logits = hk.Linear(self.num_actions)(logits) - - # value head - v = hk.Conv2D(output_channels=1, kernel_shape=1)(x) - v = hk.BatchNorm(True, True, 0.9)(v, is_training, test_local_stats) - v = jax.nn.relu(v) - v = hk.Flatten()(v) - v = hk.Linear(self.num_channels)(v) - v = jax.nn.relu(v) - v = hk.Linear(1)(v) - v = jnp.tanh(v) - v = v.reshape((-1,)) - - return logits, v + x = jnp.moveaxis(x, -1, 0) + + for layer in self.init_layers: + if isinstance(layer, eqx.nn.StatefulLayer): + x, state = layer(x, state) + else: + x = layer(x) + + for layer in self.resnet: + x, state = layer(x, state) + + for layer in self.post_resnet: + if isinstance(layer, eqx.nn.StatefulLayer): + x, state = layer(x, state) + else: + x = layer(x) + + logits = x.copy() + for layer in self.policy_head: + if isinstance(layer, eqx.nn.StatefulLayer): + logits, state = layer(logits, state) + else: + logits = layer(logits) + + v = x.copy() + for layer in self.value_head: + if isinstance(layer, eqx.nn.StatefulLayer): + v, state = layer(v, state) + else: + v = layer(v) + + return logits, v, state diff --git a/examples/alphazero/stateful_network.py b/examples/alphazero/stateful_network.py new file mode 100644 index 000000000..fa0f759a4 --- /dev/null +++ b/examples/alphazero/stateful_network.py @@ -0,0 +1,136 @@ +# We referred to Haiku's ResNet implementation: +# https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/nets/resnet.py + +import equinox as eqx +import jax +import jax.numpy as jnp + + +class BlockV1(eqx.Module): + conv1: eqx.nn.Conv2d + conv2: eqx.nn.Conv2d + norm1: eqx.nn.BatchNorm + norm2: eqx.nn.BatchNorm + + def __init__(self, in_channels, out_channels, key): + keys = jax.random.split(key, 2) + self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3, key=keys[0]) + self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3, key=keys[1]) + self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) + self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) + + def __call__(self, x, state): + i = x + x = self.conv1(x) + x, state = self.norm1(x, state) + x = jax.nn.relu(x) + x = self.conv2(x) + x, state = self.norm2(x, state) + return jax.nn.relu(x + i), state + + +class BlockV2(eqx.Module): + conv1: eqx.nn.Conv2d + conv2: eqx.nn.Conv2d + norm1: eqx.nn.BatchNorm + norm2: eqx.nn.BatchNorm + + def __init__(self, in_channels, out_channels, key): + keys = jax.random.split(key, 2) + self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3, key=keys[0]) + self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3, key=keys[1]) + self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) + self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) + + def __call__(self, x, state): + i = x + x, state = self.norm1(x, state) + x = jax.nn.relu(x) + x = self.conv1(x) + x, state = self.norm2(x, state) + x = jax.nn.relu(x + i) + x = self.conv2(x) + return x + i, state + + +class AZNet(eqx.Module): + """AlphaZero NN architecture.""" + + init_layers: list + resnet: list + post_resnet: list + policy_head: list + value_head: list + + def __init__( + self, + num_actions, + input_channels, + key, + output_channels: int = 64, + num_blocks: int = 5, + resnet_v2: bool = True, + ): + self.num_actions = num_actions + resnet_cls = BlockV2 if resnet_v2 else BlockV1 + + keys = jax.random.split(key, num_blocks + 5) + self.init_layers = [eqx.nn.Conv2d(input_channels, output_channels, kernel_size=3, key=keys[0])] + if not resnet_v2: + self.init_layers += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] + self.resnet = [resnet_cls(output_channels, output_channels, keys[i + 1]) for i in range(num_blocks)] + self.post_resnet = [] + if resnet_v2: + self.post_resnet += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] + self.policy_head = [ + eqx.nn.Conv2d(output_channels, 2, kernel_size=1, key=keys[num_blocks + 1]), + eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), + jax.nn.relu, + lambda x: x.flatten(), + eqx.nn.Linear(200, self.num_actions, key=keys[num_blocks + 2]), + ] + + self.value_head = [ + eqx.nn.Conv2d(output_channels, 1, kernel_size=1, key=keys[num_blocks + 3]), + eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), + jax.nn.relu, + lambda x: x.flatten(), + eqx.nn.Linear(200, output_channels, key=keys[num_blocks + 2]), + jax.nn.relu, + eqx.nn.Linear(output_channels, 1, key=keys[num_blocks + 2]), + jnp.tanh, + jnp.squeeze, + ] + + def __call__(self, x, state): + x = x.astype(jnp.float32) + for layer in self.init_layers: + if isinstance(layer, eqx.nn.StatefulLayer): + x, state = layer(x, state) + else: + x = layer(x) + + for layer in self.resnet: + x, state = layer(x, state) + + for layer in self.post_resnet: + if isinstance(layer, eqx.nn.StatefulLayer): + x, state = layer(x, state) + else: + x = layer(x) + + logits = x.copy() + for layer in self.policy_head: + if isinstance(layer, eqx.nn.StatefulLayer): + logits, state = layer(logits, state) + else: + logits = layer(logits) + + v = x.copy() + for layer in self.value_head: + if isinstance(layer, eqx.nn.StatefulLayer): + v, state = layer(v, state) + else: + v = layer(v) + + return logits, v, state diff --git a/examples/alphazero/train.py b/examples/alphazero/train.py index a80e5d9ec..c3d9ef608 100644 --- a/examples/alphazero/train.py +++ b/examples/alphazero/train.py @@ -19,7 +19,6 @@ from functools import partial from typing import NamedTuple -import haiku as hk import jax import jax.numpy as jnp import mctx @@ -29,6 +28,7 @@ from omegaconf import OmegaConf from pgx.experimental import auto_reset from pydantic import BaseModel +import equinox as eqx from network import AZNet @@ -65,19 +65,6 @@ class Config: env = pgx.make(config.env_id) baseline = pgx.make_baseline_model(config.env_id + "_v0") - -def forward_fn(x, is_eval=False): - net = AZNet( - num_actions=env.num_actions, - num_channels=config.num_channels, - num_blocks=config.num_layers, - resnet_v2=config.resnet_v2, - ) - policy_out, value_out = net(x, is_training=not is_eval, test_local_stats=False) - return policy_out, value_out - - -forward = hk.without_apply_rng(hk.transform_with_state(forward_fn)) optimizer = optax.adam(learning_rate=config.learning_rate) @@ -264,10 +251,13 @@ def body_fn(val): # Initialize model and opt_state dummy_state = jax.vmap(env.init)(jax.random.split(jax.random.PRNGKey(0), 2)) dummy_input = dummy_state.observation - model = forward.init(jax.random.PRNGKey(0), dummy_input) # (params, state) - opt_state = optimizer.init(params=model[0]) + model = AZNet(env.num_actions, env.observation_shape[-1], jax.random.key(config.seed), config.num_channels, config.num_layers, config.resnet_v2) + opt_state = optimizer.init(eqx.filter(model, eqx.is_array)) # replicates to all devices - model, opt_state = jax.device_put_replicated((model, opt_state), devices) + arr, static = eqx.filter((model, opt_state), eqx.is_array) + train_model, opt_state = jax.device_put_replicated(arr, devices) + train_model = eqx.combine(train_model, static[0]) + opt_state = eqx.combine(opt_state, static[1]) # Prepare checkpoint dir now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=9))) @@ -281,7 +271,7 @@ def body_fn(val): frames: int = 0 log = {"iteration": iteration, "hours": hours, "frames": frames} - rng_key = jax.random.PRNGKey(config.seed) + rng_key = jax.random.key(config.seed) while True: if iteration % config.eval_interval == 0: # Evaluation diff --git a/examples/minatar-ppo/requirements.txt b/examples/minatar-ppo/requirements.txt index af303b4c6..2020f983b 100644 --- a/examples/minatar-ppo/requirements.txt +++ b/examples/minatar-ppo/requirements.txt @@ -1,7 +1,7 @@ pgx>=2.0.0 -dm-haiku +equinox optax -distrax +distreqx omegaconf pydantic wandb diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index 173518c09..19feeb441 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -7,10 +7,10 @@ import sys import jax import jax.numpy as jnp -import haiku as hk +import equinox as eqx import optax from typing import NamedTuple, Literal -import distrax +from distreqx import distributions import pgx from pgx.experimental import auto_reset import time @@ -31,12 +31,12 @@ class PPOConfig(BaseModel): ] = "minatar-breakout" seed: int = 0 lr: float = 0.0003 - num_envs: int = 4096 + num_envs: int = 40 num_eval_envs: int = 100 num_steps: int = 128 total_timesteps: int = 20000000 update_epochs: int = 3 - minibatch_size: int = 4096 + minibatch_size: int = 40 gamma: float = 0.99 gae_lambda: float = 0.95 clip_eps: float = 0.2 @@ -58,51 +58,61 @@ class Config: num_updates = args.total_timesteps // args.num_envs // args.num_steps num_minibatches = args.num_envs * args.num_steps // args.minibatch_size +class ActorCritic(eqx.Module): + features: list + actor: list + critic: list -class ActorCritic(hk.Module): - def __init__(self, num_actions, activation="tanh"): - super().__init__() - self.num_actions = num_actions - self.activation = activation + def __init__(self, num_actions, key, activation="tanh"): assert activation in ["relu", "tanh"] + if activation == "relu": + act_fn = jax.nn.relu + else: + act_fn = jax.nn.tanh + + keys = jax.random.split(key, 8) + self.features = [ + eqx.nn.Conv2d(env.observation_shape[2], 32, 2, key=keys[0]), + # (4, 10, 10) -> (32, 9, 9) + jax.nn.relu, + eqx.nn.AvgPool2d(2, 2), + # (32, 9, 9) -> (32, 4, 4) + lambda x: x.flatten(), + eqx.nn.Linear(32 * 4 * 4, 64, key=keys[1]), + jax.nn.relu, + ] + + self.actor = [ + eqx.nn.Linear(64, 64, key=keys[2]), + act_fn, + eqx.nn.Linear(64, 64, key=keys[3]), + act_fn, + eqx.nn.Linear(64, num_actions, key=keys[4]), + ] + + self.critic = [ + eqx.nn.Linear(64, 64, key=keys[5]), + act_fn, + eqx.nn.Linear(64, 64, key=keys[6]), + act_fn, + eqx.nn.Linear(64, 1, key=keys[7]), + ] def __call__(self, x): x = x.astype(jnp.float32) - if self.activation == "relu": - activation = jax.nn.relu - else: - activation = jax.nn.tanh - x = hk.Conv2D(32, kernel_shape=2)(x) - x = jax.nn.relu(x) - x = hk.avg_pool(x, window_shape=(2, 2), - strides=(2, 2), padding="VALID") - x = x.reshape((x.shape[0], -1)) # flatten - x = hk.Linear(64)(x) - x = jax.nn.relu(x) - actor_mean = hk.Linear(64)(x) - actor_mean = activation(actor_mean) - actor_mean = hk.Linear(64)(actor_mean) - actor_mean = activation(actor_mean) - actor_mean = hk.Linear(self.num_actions)(actor_mean) - - critic = hk.Linear(64)(x) - critic = activation(critic) - critic = hk.Linear(64)(critic) - critic = activation(critic) - critic = hk.Linear(1)(critic) - + # make channels first + x = jnp.moveaxis(x, -1, 0) + for layer in self.features: + x = layer(x) + actor_mean = jnp.copy(x) + for layer in self.actor: + actor_mean = layer(actor_mean) + critic = jnp.copy(x) + for layer in self.critic: + critic = layer(critic) return actor_mean, jnp.squeeze(critic, axis=-1) -def forward_fn(x, is_eval=False): - net = ActorCritic(env.num_actions, activation="tanh") - logits, value = net(x) - return logits, value - - -forward = hk.without_apply_rng(hk.transform(forward_fn)) - - optimizer = optax.chain(optax.clip_by_global_norm( args.max_grad_norm), optax.adam(args.lr, eps=1e-5)) @@ -122,14 +132,20 @@ def _update_step(runner_state): # COLLECT TRAJECTORIES step_fn = jax.vmap(auto_reset(env.step, env.init)) + arrs, static = eqx.partition(runner_state[0], eqx.is_array) + runner_state = eqx.tree_at(lambda x: x[0], runner_state, arrs) + def _env_step(runner_state, unused): - params, opt_state, env_state, last_obs, rng = runner_state + arr_params, opt_state, env_state, last_obs, rng = runner_state + params = eqx.combine(arr_params, static) # SELECT ACTION rng, _rng = jax.random.split(rng) - logits, value = forward.apply(params, last_obs) - pi = distrax.Categorical(logits=logits) - action = pi.sample(seed=_rng) - log_prob = pi.log_prob(action) + __rng = jax.random.split(_rng, last_obs.shape[0]) + logits, value = eqx.filter_vmap(params)(last_obs) + pi = eqx.filter_vmap(distributions.Categorical)(logits) + action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) + action = action.astype('int32') + log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, action) # STEP ENV rng, _rng = jax.random.split(rng) @@ -143,17 +159,18 @@ def _env_step(runner_state, unused): log_prob, last_obs ) - runner_state = (params, opt_state, env_state, + runner_state = (arr_params, opt_state, env_state, env_state.observation, rng) return runner_state, transition runner_state, traj_batch = jax.lax.scan( _env_step, runner_state, None, args.num_steps ) + runner_state = eqx.tree_at(lambda x: x[0], runner_state, eqx.combine(runner_state[0], static)) # CALCULATE ADVANTAGE params, opt_state, env_state, last_obs, rng = runner_state - _, last_val = forward.apply(params, last_obs) + _, last_val = eqx.filter_vmap(params)(last_obs) def _calculate_gae(traj_batch, last_val): def _get_advantages(gae_and_next_value, transition): @@ -181,6 +198,8 @@ def _get_advantages(gae_and_next_value, transition): advantages, targets = _calculate_gae(traj_batch, last_val) + params_arr, static = eqx.partition(params, eqx.is_array) + # UPDATE NETWORK def _update_epoch(update_state, unused): def _update_minbatch(tup, batch_info): @@ -189,9 +208,9 @@ def _update_minbatch(tup, batch_info): def _loss_fn(params, traj_batch, gae, targets): # RERUN NETWORK - logits, value = forward.apply(params, traj_batch.obs) - pi = distrax.Categorical(logits=logits) - log_prob = pi.log_prob(traj_batch.action) + logits, value = eqx.filter_vmap(params)(traj_batch.obs) + pi = eqx.filter_vmap(distributions.Categorical)(logits) + log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, traj_batch.action) # CALCULATE VALUE LOSS value_pred_clipped = traj_batch.value + ( @@ -228,11 +247,11 @@ def _loss_fn(params, traj_batch, gae, targets): ) return total_loss, (value_loss, loss_actor, entropy) - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + grad_fn = eqx.filter_value_and_grad(_loss_fn, has_aux=True) total_loss, grads = grad_fn( - params, traj_batch, advantages, targets) + eqx.combine(params, static), traj_batch, advantages, targets) updates, opt_state = optimizer.update(grads, opt_state) - params = optax.apply_updates(params, updates) + params = eqx.apply_updates(params, updates) return (params, opt_state), total_loss params, opt_state, traj_batch, advantages, targets, rng = update_state @@ -262,19 +281,19 @@ def _loss_fn(params, traj_batch, gae, targets): advantages, targets, rng) return update_state, total_loss - update_state = (params, opt_state, traj_batch, + update_state = (params_arr, opt_state, traj_batch, advantages, targets, rng) update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, args.update_epochs ) params, opt_state, _, _, _, rng = update_state - runner_state = (params, opt_state, env_state, last_obs, rng) + runner_state = (eqx.combine(params, static), opt_state, env_state, last_obs, rng) return runner_state, loss_info return _update_step -@jax.jit +@eqx.filter_jit def evaluate(params, rng_key): step_fn = jax.vmap(env.step) rng_key, sub_key = jax.random.split(rng_key) @@ -288,11 +307,13 @@ def cond_fn(tup): def loop_fn(tup): state, R, rng_key = tup - logits, value = forward.apply(params, state.observation) + logits, value = eqx.filter_vmap(params)(state.observation) # action = logits.argmax(axis=-1) - pi = distrax.Categorical(logits=logits) + pi = eqx.filter_vmap(distributions.Categorical)(logits) rng_key, _rng = jax.random.split(rng_key) - action = pi.sample(seed=_rng) + __rng = jax.random.split(_rng, state.observation.shape[0]) + action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) + action = action.astype('int32') rng_key, _rng = jax.random.split(rng_key) keys = jax.random.split(_rng, state.observation.shape[0]) state = step_fn(state, action, keys) @@ -306,13 +327,12 @@ def train(rng): st = time.time() # INIT NETWORK rng, _rng = jax.random.split(rng) - init_x = jnp.zeros((1, ) + env.observation_shape) - params = forward.init(_rng, init_x) - opt_state = optimizer.init(params=params) + model = ActorCritic(env.num_actions, _rng, "tanh") + opt_state = optimizer.init(params=eqx.filter(model, eqx.is_inexact_array)) # INIT UPDATE FUNCTION _update_step = make_update_fn() - jitted_update_step = jax.jit(_update_step) + jitted_update_step = eqx.filter_jit(_update_step) # INIT ENV rng, _rng = jax.random.split(rng) @@ -320,7 +340,7 @@ def train(rng): env_state = jax.jit(jax.vmap(env.init))(reset_rng) rng, _rng = jax.random.split(rng) - runner_state = (params, opt_state, env_state, env_state.observation, _rng) + runner_state = (model, opt_state, env_state, env_state.observation, _rng) # warm up _, _ = jitted_update_step(runner_state) @@ -334,7 +354,7 @@ def train(rng): eval_R = evaluate(runner_state[0], _rng) log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps} print(log) - wandb.log(log) + # wandb.log(log) st = time.time() for i in range(num_updates): @@ -348,14 +368,14 @@ def train(rng): eval_R = evaluate(runner_state[0], _rng) log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps} print(log) - wandb.log(log) + # wandb.log(log) st = time.time() return runner_state if __name__ == "__main__": - wandb.init(project=args.wandb_project, config=args.dict()) + # wandb.init(project=args.wandb_project, config=args.dict()) rng = jax.random.PRNGKey(args.seed) out = train(rng) if args.save_model: diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 404dbb471..c67fa62f5 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -3,8 +3,7 @@ pytest pytest-xdist matplotlib ipython -# hot fix. to avoid errors in Py3.8 -dm-haiku==0.0.10 +equinox pytest-cov pgx-minatar black From e571083f0d9c80c381cabfb8ba26740b0f550107 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:27:03 -0800 Subject: [PATCH 02/17] work --- docs/api_usage.md | 2 +- examples/alphazero/network.py | 26 ++--- examples/alphazero/stateful_network.py | 136 ------------------------- examples/alphazero/train.py | 133 +++++++++++++----------- examples/minatar-ppo/train.py | 6 +- 5 files changed, 93 insertions(+), 210 deletions(-) delete mode 100644 examples/alphazero/stateful_network.py diff --git a/docs/api_usage.md b/docs/api_usage.md index c3a877bec..39a15b030 100644 --- a/docs/api_usage.md +++ b/docs/api_usage.md @@ -68,7 +68,7 @@ init_fn = jax.jit(jax.vmap(env.init)) step_fn = jax.jit(jax.vmap(env.step)) # Prepare baseline model -# Note that it additionaly requires equinox library ($ pip install equinox) +# Note that it additionaly requires equinox library ($ pip install dm-haiku) model_id = "go_9x9_v0" model = pgx.make_baseline_model(model_id) diff --git a/examples/alphazero/network.py b/examples/alphazero/network.py index 4466f74e8..67d68909d 100644 --- a/examples/alphazero/network.py +++ b/examples/alphazero/network.py @@ -14,8 +14,8 @@ class BlockV1(eqx.Module): def __init__(self, in_channels, out_channels, key): keys = jax.random.split(key, 2) - self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3, key=keys[0]) - self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3, key=keys[1]) + self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, padding="SAME", kernel_size=3, key=keys[0]) + self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, padding="SAME", kernel_size=3, key=keys[1]) self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) @@ -37,9 +37,9 @@ class BlockV2(eqx.Module): def __init__(self, in_channels, out_channels, key): keys = jax.random.split(key, 2) - self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3, key=keys[0]) - self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3, key=keys[1]) - self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) + self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, padding="SAME", kernel_size=3, key=keys[0]) + self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, padding="SAME", kernel_size=3, key=keys[1]) + self.norm1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9) self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) def __call__(self, x, state): @@ -48,7 +48,7 @@ def __call__(self, x, state): x = jax.nn.relu(x) x = self.conv1(x) x, state = self.norm2(x, state) - x = jax.nn.relu(x + i) + x = jax.nn.relu(x) x = self.conv2(x) return x + i, state @@ -71,7 +71,6 @@ def __init__( num_blocks: int = 5, resnet_v2: bool = True, ): - self.num_actions = num_actions resnet_cls = BlockV2 if resnet_v2 else BlockV1 keys = jax.random.split(key, num_blocks + 5) @@ -84,18 +83,19 @@ def __init__( self.post_resnet += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] self.policy_head = [ eqx.nn.Conv2d(output_channels, 2, kernel_size=1, key=keys[num_blocks + 1]), - eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), + eqx.nn.BatchNorm(2, "batch", momentum=0.9), jax.nn.relu, lambda x: x.flatten(), - eqx.nn.Linear(200, self.num_actions, key=keys[num_blocks + 2]), + # TODO: infer 98 from inputs + eqx.nn.Linear(98, num_actions, key=keys[num_blocks + 2]), ] self.value_head = [ eqx.nn.Conv2d(output_channels, 1, kernel_size=1, key=keys[num_blocks + 3]), - eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), + eqx.nn.BatchNorm(1, "batch", momentum=0.9), jax.nn.relu, lambda x: x.flatten(), - eqx.nn.Linear(200, output_channels, key=keys[num_blocks + 2]), + eqx.nn.Linear(49, output_channels, key=keys[num_blocks + 2]), jax.nn.relu, eqx.nn.Linear(output_channels, 1, key=keys[num_blocks + 2]), jnp.tanh, @@ -105,7 +105,7 @@ def __init__( def __call__(self, x, state): x = x.astype(jnp.float32) x = jnp.moveaxis(x, -1, 0) - + for layer in self.init_layers: if isinstance(layer, eqx.nn.StatefulLayer): x, state = layer(x, state) @@ -135,4 +135,4 @@ def __call__(self, x, state): else: v = layer(v) - return logits, v, state + return (logits, v), state diff --git a/examples/alphazero/stateful_network.py b/examples/alphazero/stateful_network.py deleted file mode 100644 index fa0f759a4..000000000 --- a/examples/alphazero/stateful_network.py +++ /dev/null @@ -1,136 +0,0 @@ -# We referred to Haiku's ResNet implementation: -# https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/nets/resnet.py - -import equinox as eqx -import jax -import jax.numpy as jnp - - -class BlockV1(eqx.Module): - conv1: eqx.nn.Conv2d - conv2: eqx.nn.Conv2d - norm1: eqx.nn.BatchNorm - norm2: eqx.nn.BatchNorm - - def __init__(self, in_channels, out_channels, key): - keys = jax.random.split(key, 2) - self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3, key=keys[0]) - self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3, key=keys[1]) - self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) - self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) - - def __call__(self, x, state): - i = x - x = self.conv1(x) - x, state = self.norm1(x, state) - x = jax.nn.relu(x) - x = self.conv2(x) - x, state = self.norm2(x, state) - return jax.nn.relu(x + i), state - - -class BlockV2(eqx.Module): - conv1: eqx.nn.Conv2d - conv2: eqx.nn.Conv2d - norm1: eqx.nn.BatchNorm - norm2: eqx.nn.BatchNorm - - def __init__(self, in_channels, out_channels, key): - keys = jax.random.split(key, 2) - self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3, key=keys[0]) - self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3, key=keys[1]) - self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) - self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) - - def __call__(self, x, state): - i = x - x, state = self.norm1(x, state) - x = jax.nn.relu(x) - x = self.conv1(x) - x, state = self.norm2(x, state) - x = jax.nn.relu(x + i) - x = self.conv2(x) - return x + i, state - - -class AZNet(eqx.Module): - """AlphaZero NN architecture.""" - - init_layers: list - resnet: list - post_resnet: list - policy_head: list - value_head: list - - def __init__( - self, - num_actions, - input_channels, - key, - output_channels: int = 64, - num_blocks: int = 5, - resnet_v2: bool = True, - ): - self.num_actions = num_actions - resnet_cls = BlockV2 if resnet_v2 else BlockV1 - - keys = jax.random.split(key, num_blocks + 5) - self.init_layers = [eqx.nn.Conv2d(input_channels, output_channels, kernel_size=3, key=keys[0])] - if not resnet_v2: - self.init_layers += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] - self.resnet = [resnet_cls(output_channels, output_channels, keys[i + 1]) for i in range(num_blocks)] - self.post_resnet = [] - if resnet_v2: - self.post_resnet += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] - self.policy_head = [ - eqx.nn.Conv2d(output_channels, 2, kernel_size=1, key=keys[num_blocks + 1]), - eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), - jax.nn.relu, - lambda x: x.flatten(), - eqx.nn.Linear(200, self.num_actions, key=keys[num_blocks + 2]), - ] - - self.value_head = [ - eqx.nn.Conv2d(output_channels, 1, kernel_size=1, key=keys[num_blocks + 3]), - eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), - jax.nn.relu, - lambda x: x.flatten(), - eqx.nn.Linear(200, output_channels, key=keys[num_blocks + 2]), - jax.nn.relu, - eqx.nn.Linear(output_channels, 1, key=keys[num_blocks + 2]), - jnp.tanh, - jnp.squeeze, - ] - - def __call__(self, x, state): - x = x.astype(jnp.float32) - for layer in self.init_layers: - if isinstance(layer, eqx.nn.StatefulLayer): - x, state = layer(x, state) - else: - x = layer(x) - - for layer in self.resnet: - x, state = layer(x, state) - - for layer in self.post_resnet: - if isinstance(layer, eqx.nn.StatefulLayer): - x, state = layer(x, state) - else: - x = layer(x) - - logits = x.copy() - for layer in self.policy_head: - if isinstance(layer, eqx.nn.StatefulLayer): - logits, state = layer(logits, state) - else: - logits = layer(logits) - - v = x.copy() - for layer in self.value_head: - if isinstance(layer, eqx.nn.StatefulLayer): - v, state = layer(v, state) - else: - v = layer(v) - - return logits, v, state diff --git a/examples/alphazero/train.py b/examples/alphazero/train.py index c3d9ef608..57477bad8 100644 --- a/examples/alphazero/train.py +++ b/examples/alphazero/train.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +# import os + +# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(4) + + import datetime import os import pickle @@ -68,34 +73,6 @@ class Config: optimizer = optax.adam(learning_rate=config.learning_rate) -def recurrent_fn(model, rng_key: jnp.ndarray, action: jnp.ndarray, state: pgx.State): - # model: params - # state: embedding - del rng_key - model_params, model_state = model - - current_player = state.current_player - state = jax.vmap(env.step)(state, action) - - (logits, value), _ = forward.apply(model_params, model_state, state.observation, is_eval=True) - # mask invalid actions - logits = logits - jnp.max(logits, axis=-1, keepdims=True) - logits = jnp.where(state.legal_action_mask, logits, jnp.finfo(logits.dtype).min) - - reward = state.rewards[jnp.arange(state.rewards.shape[0]), current_player] - value = jnp.where(state.terminated, 0.0, value) - discount = -1.0 * jnp.ones_like(value) - discount = jnp.where(state.terminated, 0.0, discount) - - recurrent_fn_output = mctx.RecurrentFnOutput( - reward=reward, - discount=discount, - prior_logits=logits, - value=value, - ) - return recurrent_fn_output, state - - class SelfplayOutput(NamedTuple): obs: jnp.ndarray reward: jnp.ndarray @@ -104,22 +81,55 @@ class SelfplayOutput(NamedTuple): discount: jnp.ndarray -@jax.pmap +@partial(eqx.filter_pmap, in_axes=(None, 0)) def selfplay(model, rng_key: jnp.ndarray) -> SelfplayOutput: model_params, model_state = model + model_params = eqx.nn.inference_mode(model_params) + model = (model_params, model_state) + arr, static = eqx.partition(model, eqx.is_array) + + def recurrent_fn(model, rng_key: jnp.ndarray, action: jnp.ndarray, state: pgx.State): + del rng_key + model = eqx.combine(model, static) + model_params, model_state = model + + current_player = state.current_player + state = jax.vmap(env.step)(state, action) + + # (logits, value), _ = forward.apply(model_params, model_state, state.observation, is_eval=True) + (logits, value), _ = eqx.filter_vmap(model_params, in_axes=(0, None), out_axes=(0, None), axis_name="batch")( + state.observation, model_state + ) + # mask invalid actions + logits = logits - jnp.max(logits, axis=-1, keepdims=True) + logits = jnp.where(state.legal_action_mask, logits, jnp.finfo(logits.dtype).min) + + reward = state.rewards[jnp.arange(state.rewards.shape[0]), current_player] + value = jnp.where(state.terminated, 0.0, value) + discount = -1.0 * jnp.ones_like(value) + discount = jnp.where(state.terminated, 0.0, discount) + + recurrent_fn_output = mctx.RecurrentFnOutput( + reward=reward, + discount=discount, + prior_logits=logits, + value=value, + ) + return recurrent_fn_output, state + batch_size = config.selfplay_batch_size // num_devices def step_fn(state, key) -> SelfplayOutput: key1, key2 = jax.random.split(key) observation = state.observation - (logits, value), _ = forward.apply( - model_params, model_state, state.observation, is_eval=True + (logits, value), _ = eqx.filter_vmap(model_params, in_axes=(0, None), out_axes=(0, None), axis_name="batch")( + state.observation, model_state ) root = mctx.RootFnOutput(prior_logits=logits, value=value, embedding=state) policy_output = mctx.gumbel_muzero_policy( - params=model, + params=arr, rng_key=key1, root=root, recurrent_fn=recurrent_fn, @@ -187,9 +197,9 @@ def body_fn(carry, i): def loss_fn(model_params, model_state, samples: Sample): - (logits, value), model_state = forward.apply( - model_params, model_state, samples.obs, is_eval=False - ) + (logits, value), model_state = eqx.filter_vmap( + model_params, in_axes=(0, None), out_axes=(0, None), axis_name="batch" + )(samples.obs, model_state) policy_loss = optax.softmax_cross_entropy(logits, samples.policy_tgt) policy_loss = jnp.mean(policy_loss) @@ -200,25 +210,26 @@ def loss_fn(model_params, model_state, samples: Sample): return policy_loss + value_loss, (model_state, policy_loss, value_loss) -@partial(jax.pmap, axis_name="i") +@partial(eqx.filter_pmap, axis_name="i", in_axes=(None, None, 0), out_axes=(None, None, 0, 0)) def train(model, opt_state, data: Sample): model_params, model_state = model - grads, (model_state, policy_loss, value_loss) = jax.grad(loss_fn, has_aux=True)( + grads, (model_state, policy_loss, value_loss) = eqx.filter_grad(loss_fn, has_aux=True)( model_params, model_state, data ) grads = jax.lax.pmean(grads, axis_name="i") updates, opt_state = optimizer.update(grads, opt_state) - model_params = optax.apply_updates(model_params, updates) + model_params = eqx.apply_updates(model_params, updates) model = (model_params, model_state) return model, opt_state, policy_loss, value_loss -@jax.pmap +@partial(eqx.filter_pmap, in_axes=(0, None)) def evaluate(rng_key, my_model): - """A simplified evaluation by sampling. Only for debugging. + """A simplified evaluation by sampling. Only for debugging. Please use MCTS and run tournaments for serious evaluation.""" my_player = 0 - my_model_params, my_model_state = my_model + my_model, my_model_state = my_model + inference_model = eqx.nn.inference_mode(my_model) key, subkey = jax.random.split(rng_key) batch_size = config.selfplay_batch_size // num_devices @@ -227,8 +238,8 @@ def evaluate(rng_key, my_model): def body_fn(val): key, state, R = val - (my_logits, _), _ = forward.apply( - my_model_params, my_model_state, state.observation, is_eval=True + (my_logits, _), _ = eqx.filter_vmap(inference_model, in_axes=(0, None), out_axes=(0, None), axis_name="batch")( + state.observation, my_model_state ) opp_logits, _ = baseline(state.observation) is_my_turn = (state.current_player == my_player).reshape((-1, 1)) @@ -239,9 +250,7 @@ def body_fn(val): R = R + state.rewards[jnp.arange(batch_size), my_player] return (key, state, R) - _, _, R = jax.lax.while_loop( - lambda x: ~(x[1].terminated.all()), body_fn, (key, state, jnp.zeros(batch_size)) - ) + _, _, R = jax.lax.while_loop(lambda x: ~(x[1].terminated.all()), body_fn, (key, state, jnp.zeros(batch_size))) return R @@ -251,13 +260,23 @@ def body_fn(val): # Initialize model and opt_state dummy_state = jax.vmap(env.init)(jax.random.split(jax.random.PRNGKey(0), 2)) dummy_input = dummy_state.observation - model = AZNet(env.num_actions, env.observation_shape[-1], jax.random.key(config.seed), config.num_channels, config.num_layers, config.resnet_v2) - opt_state = optimizer.init(eqx.filter(model, eqx.is_array)) + init_model, state = eqx.nn.make_with_state(AZNet)( + env.num_actions, + env.observation_shape[-1], + jax.random.key(config.seed), + config.num_channels, + config.num_layers, + config.resnet_v2, + ) + opt_state = optimizer.init(eqx.filter(init_model, eqx.is_array)) # replicates to all devices - arr, static = eqx.filter((model, opt_state), eqx.is_array) - train_model, opt_state = jax.device_put_replicated(arr, devices) - train_model = eqx.combine(train_model, static[0]) - opt_state = eqx.combine(opt_state, static[1]) + arr, static = eqx.partition((init_model, opt_state, state), eqx.is_array) + # arr_rep = jax.device_put_replicated(arr, devices) + arr_rep = arr + train_model = eqx.combine(arr_rep[0], static[0]) + opt_state = eqx.combine(arr_rep[1], static[1]) + state = eqx.combine(arr_rep[2], static[2]) + model = (train_model, state) # Prepare checkpoint dir now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=9))) @@ -288,7 +307,8 @@ def body_fn(val): ) # Store checkpoints - model_0, opt_state_0 = jax.tree_util.tree_map(lambda x: x[0], (model, opt_state)) + # model_0, opt_state_0 = jax.tree_util.tree_map(lambda x: x[0], (train_model, opt_state)) + model_0, opt_state_0 = eqx.filter((model[0], opt_state), eqx.is_array) with open(os.path.join(ckpt_dir, f"{iteration:06d}.ckpt"), "wb") as f: dic = { "config": config, @@ -302,7 +322,7 @@ def body_fn(val): "env_id": env.id, "env_version": env.version, } - pickle.dump(dic, f) + # pickle.dump(dic, f) print(log) wandb.log(log) @@ -328,9 +348,7 @@ def body_fn(val): ixs = jax.random.permutation(subkey, jnp.arange(samples.obs.shape[0])) samples = jax.tree_util.tree_map(lambda x: x[ixs], samples) # shuffle num_updates = samples.obs.shape[0] // config.training_batch_size - minibatches = jax.tree_util.tree_map( - lambda x: x.reshape((num_updates, num_devices, -1) + x.shape[1:]), samples - ) + minibatches = jax.tree_util.tree_map(lambda x: x.reshape((num_updates, num_devices, -1) + x.shape[1:]), samples) # Training policy_losses, value_losses = [], [] @@ -339,6 +357,7 @@ def body_fn(val): model, opt_state, policy_loss, value_loss = train(model, opt_state, minibatch) policy_losses.append(policy_loss.mean().item()) value_losses.append(value_loss.mean().item()) + policy_loss = sum(policy_losses) / len(policy_losses) value_loss = sum(value_losses) / len(value_losses) diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index 19feeb441..251189736 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -354,7 +354,7 @@ def train(rng): eval_R = evaluate(runner_state[0], _rng) log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps} print(log) - # wandb.log(log) + wandb.log(log) st = time.time() for i in range(num_updates): @@ -368,14 +368,14 @@ def train(rng): eval_R = evaluate(runner_state[0], _rng) log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps} print(log) - # wandb.log(log) + wandb.log(log) st = time.time() return runner_state if __name__ == "__main__": - # wandb.init(project=args.wandb_project, config=args.dict()) + wandb.init(project=args.wandb_project, config=args.dict()) rng = jax.random.PRNGKey(args.seed) out = train(rng) if args.save_model: From de12271c5102d99e02b437e52c866bf993ee69be Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 4 Feb 2025 18:13:06 -0800 Subject: [PATCH 03/17] distrax --- examples/minatar-ppo/train.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index 251189736..6ae69294c 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -11,6 +11,7 @@ import optax from typing import NamedTuple, Literal from distreqx import distributions +import distrax import pgx from pgx.experimental import auto_reset import time @@ -31,12 +32,12 @@ class PPOConfig(BaseModel): ] = "minatar-breakout" seed: int = 0 lr: float = 0.0003 - num_envs: int = 40 + num_envs: int = 4096 num_eval_envs: int = 100 num_steps: int = 128 total_timesteps: int = 20000000 update_epochs: int = 3 - minibatch_size: int = 40 + minibatch_size: int = 4096 gamma: float = 0.99 gae_lambda: float = 0.95 clip_eps: float = 0.2 @@ -142,10 +143,13 @@ def _env_step(runner_state, unused): rng, _rng = jax.random.split(rng) __rng = jax.random.split(_rng, last_obs.shape[0]) logits, value = eqx.filter_vmap(params)(last_obs) - pi = eqx.filter_vmap(distributions.Categorical)(logits) - action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) - action = action.astype('int32') - log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, action) + pi = distrax.Categorical(logits=logits) + action = pi.sample(seed=_rng) + log_prob = pi.log_prob(action) + # pi = eqx.filter_vmap(distributions.Categorical)(logits) + # action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) + # action = action.astype('int32') + # log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, action) # STEP ENV rng, _rng = jax.random.split(rng) @@ -209,8 +213,10 @@ def _update_minbatch(tup, batch_info): def _loss_fn(params, traj_batch, gae, targets): # RERUN NETWORK logits, value = eqx.filter_vmap(params)(traj_batch.obs) - pi = eqx.filter_vmap(distributions.Categorical)(logits) - log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, traj_batch.action) + pi = distrax.Categorical(logits=logits) + log_prob = pi.log_prob(traj_batch.action) + # pi = eqx.filter_vmap(distributions.Categorical)(logits) + # log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, traj_batch.action) # CALCULATE VALUE LOSS value_pred_clipped = traj_batch.value + ( @@ -309,11 +315,14 @@ def loop_fn(tup): state, R, rng_key = tup logits, value = eqx.filter_vmap(params)(state.observation) # action = logits.argmax(axis=-1) - pi = eqx.filter_vmap(distributions.Categorical)(logits) + pi = distrax.Categorical(logits=logits) rng_key, _rng = jax.random.split(rng_key) - __rng = jax.random.split(_rng, state.observation.shape[0]) - action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) - action = action.astype('int32') + action = pi.sample(seed=_rng) + # pi = eqx.filter_vmap(distributions.Categorical)(logits) + # rng_key, _rng = jax.random.split(rng_key) + # __rng = jax.random.split(_rng, state.observation.shape[0]) + # action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) + # action = action.astype('int32') rng_key, _rng = jax.random.split(rng_key) keys = jax.random.split(_rng, state.observation.shape[0]) state = step_fn(state, action, keys) From 733ecd1351a885c908a5d7dd0582a905dd07bed8 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 4 Feb 2025 19:01:54 -0800 Subject: [PATCH 04/17] init --- docs/api_usage.md | 2 +- examples/minatar-ppo/train.py | 145 +++++++++++++++++----------------- 2 files changed, 72 insertions(+), 75 deletions(-) diff --git a/docs/api_usage.md b/docs/api_usage.md index 39a15b030..f8e647d11 100644 --- a/docs/api_usage.md +++ b/docs/api_usage.md @@ -68,7 +68,7 @@ init_fn = jax.jit(jax.vmap(env.init)) step_fn = jax.jit(jax.vmap(env.step)) # Prepare baseline model -# Note that it additionaly requires equinox library ($ pip install dm-haiku) +# Note that it additionaly requires Haiku library ($ pip install dm-haiku) model_id = "go_9x9_v0" model = pgx.make_baseline_model(model_id) diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index 6ae69294c..598194ed6 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -59,6 +59,34 @@ class Config: num_updates = args.total_timesteps // args.num_envs // args.num_steps num_minibatches = args.num_envs * args.num_steps // args.minibatch_size + +def init_weight(layer, key): + def where(m): + return m.weight + + s = layer.weight.shape + if len(s) == 2: + f = s[1] + else: + f = s[1] * s[2] * s[3] + return eqx.tree_at(where, layer, (1.0 / jnp.sqrt(f)) * jax.random.truncated_normal(key, -2.0, 2.0, s)) + + +def init_bias(layer): + def where(m): + return m.bias + + if layer.bias is not None: + return eqx.tree_at(where, layer, jnp.zeros_like(layer.bias)) + return layer + + +def truncated_normal_init(layer, key): + layer = init_weight(layer, key) + layer = init_bias(layer) + return layer + + class ActorCritic(eqx.Module): features: list actor: list @@ -70,34 +98,34 @@ def __init__(self, num_actions, key, activation="tanh"): act_fn = jax.nn.relu else: act_fn = jax.nn.tanh - + keys = jax.random.split(key, 8) self.features = [ - eqx.nn.Conv2d(env.observation_shape[2], 32, 2, key=keys[0]), - # (4, 10, 10) -> (32, 9, 9) - jax.nn.relu, - eqx.nn.AvgPool2d(2, 2), - # (32, 9, 9) -> (32, 4, 4) - lambda x: x.flatten(), - eqx.nn.Linear(32 * 4 * 4, 64, key=keys[1]), - jax.nn.relu, - ] + truncated_normal_init(eqx.nn.Conv2d(env.observation_shape[2], 32, 2, key=keys[0]), keys[0]), + # (4, 10, 10) -> (32, 9, 9) + jax.nn.relu, + eqx.nn.AvgPool2d(2, 2), + # (32, 9, 9) -> (32, 4, 4) + lambda x: x.flatten(), + truncated_normal_init(eqx.nn.Linear(32 * 4 * 4, 64, key=keys[1]), key=keys[1]), + jax.nn.relu, + ] self.actor = [ - eqx.nn.Linear(64, 64, key=keys[2]), - act_fn, - eqx.nn.Linear(64, 64, key=keys[3]), - act_fn, - eqx.nn.Linear(64, num_actions, key=keys[4]), - ] + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[2]), keys[2]), + act_fn, + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[3]), keys[3]), + act_fn, + truncated_normal_init(eqx.nn.Linear(64, num_actions, key=keys[4]), keys[4]), + ] self.critic = [ - eqx.nn.Linear(64, 64, key=keys[5]), - act_fn, - eqx.nn.Linear(64, 64, key=keys[6]), - act_fn, - eqx.nn.Linear(64, 1, key=keys[7]), - ] + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[5]), keys[5]), + act_fn, + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[6]), keys[6]), + act_fn, + truncated_normal_init(eqx.nn.Linear(64, 1, key=keys[7]), keys[7]), + ] def __call__(self, x): x = x.astype(jnp.float32) @@ -114,8 +142,7 @@ def __call__(self, x): return actor_mean, jnp.squeeze(critic, axis=-1) -optimizer = optax.chain(optax.clip_by_global_norm( - args.max_grad_norm), optax.adam(args.lr, eps=1e-5)) +optimizer = optax.chain(optax.clip_by_global_norm(args.max_grad_norm), optax.adam(args.lr, eps=1e-5)) class Transition(NamedTuple): @@ -156,20 +183,12 @@ def _env_step(runner_state, unused): keys = jax.random.split(_rng, env_state.observation.shape[0]) env_state = step_fn(env_state, action, keys) transition = Transition( - env_state.terminated, - action, - value, - jnp.squeeze(env_state.rewards), - log_prob, - last_obs + env_state.terminated, action, value, jnp.squeeze(env_state.rewards), log_prob, last_obs ) - runner_state = (arr_params, opt_state, env_state, - env_state.observation, rng) + runner_state = (arr_params, opt_state, env_state, env_state.observation, rng) return runner_state, transition - runner_state, traj_batch = jax.lax.scan( - _env_step, runner_state, None, args.num_steps - ) + runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, args.num_steps) runner_state = eqx.tree_at(lambda x: x[0], runner_state, eqx.combine(runner_state[0], static)) # CALCULATE ADVANTAGE @@ -185,10 +204,7 @@ def _get_advantages(gae_and_next_value, transition): transition.reward, ) delta = reward + args.gamma * next_value * (1 - done) - value - gae = ( - delta - + args.gamma * args.gae_lambda * (1 - done) * gae - ) + gae = delta + args.gamma * args.gae_lambda * (1 - done) * gae return (gae, value), gae _, advantages = jax.lax.scan( @@ -219,16 +235,12 @@ def _loss_fn(params, traj_batch, gae, targets): # log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, traj_batch.action) # CALCULATE VALUE LOSS - value_pred_clipped = traj_batch.value + ( - value - traj_batch.value - ).clip(-args.clip_eps, args.clip_eps) - value_losses = jnp.square(value - targets) - value_losses_clipped = jnp.square( - value_pred_clipped - targets) - value_loss = ( - 0.5 * jnp.maximum(value_losses, - value_losses_clipped).mean() + value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( + -args.clip_eps, args.clip_eps ) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() # CALCULATE ACTOR LOSS ratio = jnp.exp(log_prob - traj_batch.log_prob) @@ -246,16 +258,11 @@ def _loss_fn(params, traj_batch, gae, targets): loss_actor = loss_actor.mean() entropy = pi.entropy().mean() - total_loss = ( - loss_actor - + args.vf_coef * value_loss - - args.ent_coef * entropy - ) + total_loss = loss_actor + args.vf_coef * value_loss - args.ent_coef * entropy return total_loss, (value_loss, loss_actor, entropy) grad_fn = eqx.filter_value_and_grad(_loss_fn, has_aux=True) - total_loss, grads = grad_fn( - eqx.combine(params, static), traj_batch, advantages, targets) + total_loss, grads = grad_fn(eqx.combine(params, static), traj_batch, advantages, targets) updates, opt_state = optimizer.update(grads, opt_state) params = eqx.apply_updates(params, updates) return (params, opt_state), total_loss @@ -268,34 +275,23 @@ def _loss_fn(params, traj_batch, gae, targets): ), "batch size must be equal to number of steps * number of envs" permutation = jax.random.permutation(_rng, batch_size) batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map( - lambda x: x.reshape((batch_size,) + x.shape[2:]), batch - ) - shuffled_batch = jax.tree_util.tree_map( - lambda x: jnp.take(x, permutation, axis=0), batch - ) + batch = jax.tree_util.tree_map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch) + shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch) minibatches = jax.tree_util.tree_map( - lambda x: jnp.reshape( - x, [num_minibatches, -1] + list(x.shape[1:]) - ), + lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])), shuffled_batch, ) - (params, opt_state), total_loss = jax.lax.scan( - _update_minbatch, (params, opt_state), minibatches - ) - update_state = (params, opt_state, traj_batch, - advantages, targets, rng) + (params, opt_state), total_loss = jax.lax.scan(_update_minbatch, (params, opt_state), minibatches) + update_state = (params, opt_state, traj_batch, advantages, targets, rng) return update_state, total_loss - update_state = (params_arr, opt_state, traj_batch, - advantages, targets, rng) - update_state, loss_info = jax.lax.scan( - _update_epoch, update_state, None, args.update_epochs - ) + update_state = (params_arr, opt_state, traj_batch, advantages, targets, rng) + update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, args.update_epochs) params, opt_state, _, _, _, rng = update_state runner_state = (eqx.combine(params, static), opt_state, env_state, last_obs, rng) return runner_state, loss_info + return _update_step @@ -327,6 +323,7 @@ def loop_fn(tup): keys = jax.random.split(_rng, state.observation.shape[0]) state = step_fn(state, action, keys) return state, R + state.rewards, rng_key + state, R, _ = jax.lax.while_loop(cond_fn, loop_fn, (state, R, rng_key)) return R.mean() From caa2fc57d0cd3763aefe0c105fd25c98b46a35e2 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 5 Feb 2025 20:59:59 -0800 Subject: [PATCH 05/17] fix --- examples/minatar-ppo/train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index 598194ed6..8060109c8 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -102,12 +102,13 @@ def __init__(self, num_actions, key, activation="tanh"): keys = jax.random.split(key, 8) self.features = [ truncated_normal_init(eqx.nn.Conv2d(env.observation_shape[2], 32, 2, key=keys[0]), keys[0]), - # (4, 10, 10) -> (32, 9, 9) + # (4, 10, 10) -> (32, 10, 10) jax.nn.relu, + lambda x: jnp.moveaxis(x, 0, -1), eqx.nn.AvgPool2d(2, 2), - # (32, 9, 9) -> (32, 4, 4) + # (10, 10, 32) -> (10, 5, 16) lambda x: x.flatten(), - truncated_normal_init(eqx.nn.Linear(32 * 4 * 4, 64, key=keys[1]), key=keys[1]), + truncated_normal_init(eqx.nn.Linear(10 * 5 * 16, 64, key=keys[1]), key=keys[1]), jax.nn.relu, ] From fbd30a211369613120bbe71017854445cb2cfb58 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 5 Feb 2025 21:06:05 -0800 Subject: [PATCH 06/17] no init --- examples/minatar-ppo/train.py | 44 ++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index 8060109c8..2819fc3cf 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -100,32 +100,60 @@ def __init__(self, num_actions, key, activation="tanh"): act_fn = jax.nn.tanh keys = jax.random.split(key, 8) + # self.features = [ + # truncated_normal_init(eqx.nn.Conv2d(env.observation_shape[2], 32, 2, padding="SAME", key=keys[0]), keys[0]), + # # (4, 10, 10) -> (32, 10, 10) + # jax.nn.relu, + # lambda x: jnp.moveaxis(x, 0, -1), + # eqx.nn.AvgPool2d(2, 2), + # # (10, 10, 32) -> (10, 5, 16) + # lambda x: x.flatten(), + # truncated_normal_init(eqx.nn.Linear(10 * 5 * 16, 64, key=keys[1]), key=keys[1]), + # jax.nn.relu, + # ] + + # self.actor = [ + # truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[2]), keys[2]), + # act_fn, + # truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[3]), keys[3]), + # act_fn, + # truncated_normal_init(eqx.nn.Linear(64, num_actions, key=keys[4]), keys[4]), + # ] + + # self.critic = [ + # truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[5]), keys[5]), + # act_fn, + # truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[6]), keys[6]), + # act_fn, + # truncated_normal_init(eqx.nn.Linear(64, 1, key=keys[7]), keys[7]), + # ] + self.features = [ - truncated_normal_init(eqx.nn.Conv2d(env.observation_shape[2], 32, 2, key=keys[0]), keys[0]), + eqx.nn.Conv2d(env.observation_shape[2], 32, 2, padding="SAME", key=keys[0]), # (4, 10, 10) -> (32, 10, 10) jax.nn.relu, lambda x: jnp.moveaxis(x, 0, -1), eqx.nn.AvgPool2d(2, 2), # (10, 10, 32) -> (10, 5, 16) lambda x: x.flatten(), - truncated_normal_init(eqx.nn.Linear(10 * 5 * 16, 64, key=keys[1]), key=keys[1]), + eqx.nn.Linear(10 * 5 * 16, 64, key=keys[1]), jax.nn.relu, ] self.actor = [ - truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[2]), keys[2]), + eqx.nn.Linear(64, 64, key=keys[2]), act_fn, - truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[3]), keys[3]), + eqx.nn.Linear(64, 64, key=keys[3]), act_fn, - truncated_normal_init(eqx.nn.Linear(64, num_actions, key=keys[4]), keys[4]), + eqx.nn.Linear(64, num_actions, key=keys[4]), ] self.critic = [ - truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[5]), keys[5]), + eqx.nn.Linear(64, 64, key=keys[5]), act_fn, - truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[6]), keys[6]), + eqx.nn.Linear(64, 64, key=keys[6]), act_fn, - truncated_normal_init(eqx.nn.Linear(64, 1, key=keys[7]), keys[7]), + eqx.nn.Linear(64, 1, key=keys[7]), ] def __call__(self, x): From f32fd179f5f1024a22ecda89479e626b8216bedb Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 5 Feb 2025 21:09:40 -0800 Subject: [PATCH 07/17] dist --- examples/minatar-ppo/train.py | 92 ++++++++--------------------------- 1 file changed, 19 insertions(+), 73 deletions(-) diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index 2819fc3cf..b27bd74f8 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -60,33 +60,6 @@ class Config: num_minibatches = args.num_envs * args.num_steps // args.minibatch_size -def init_weight(layer, key): - def where(m): - return m.weight - - s = layer.weight.shape - if len(s) == 2: - f = s[1] - else: - f = s[1] * s[2] * s[3] - return eqx.tree_at(where, layer, (1.0 / jnp.sqrt(f)) * jax.random.truncated_normal(key, -2.0, 2.0, s)) - - -def init_bias(layer): - def where(m): - return m.bias - - if layer.bias is not None: - return eqx.tree_at(where, layer, jnp.zeros_like(layer.bias)) - return layer - - -def truncated_normal_init(layer, key): - layer = init_weight(layer, key) - layer = init_bias(layer) - return layer - - class ActorCritic(eqx.Module): features: list actor: list @@ -100,33 +73,6 @@ def __init__(self, num_actions, key, activation="tanh"): act_fn = jax.nn.tanh keys = jax.random.split(key, 8) - # self.features = [ - # truncated_normal_init(eqx.nn.Conv2d(env.observation_shape[2], 32, 2, padding="SAME", key=keys[0]), keys[0]), - # # (4, 10, 10) -> (32, 10, 10) - # jax.nn.relu, - # lambda x: jnp.moveaxis(x, 0, -1), - # eqx.nn.AvgPool2d(2, 2), - # # (10, 10, 32) -> (10, 5, 16) - # lambda x: x.flatten(), - # truncated_normal_init(eqx.nn.Linear(10 * 5 * 16, 64, key=keys[1]), key=keys[1]), - # jax.nn.relu, - # ] - - # self.actor = [ - # truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[2]), keys[2]), - # act_fn, - # truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[3]), keys[3]), - # act_fn, - # truncated_normal_init(eqx.nn.Linear(64, num_actions, key=keys[4]), keys[4]), - # ] - - # self.critic = [ - # truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[5]), keys[5]), - # act_fn, - # truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[6]), keys[6]), - # act_fn, - # truncated_normal_init(eqx.nn.Linear(64, 1, key=keys[7]), keys[7]), - # ] self.features = [ eqx.nn.Conv2d(env.observation_shape[2], 32, 2, padding="SAME", key=keys[0]), @@ -198,14 +144,14 @@ def _env_step(runner_state, unused): # SELECT ACTION rng, _rng = jax.random.split(rng) __rng = jax.random.split(_rng, last_obs.shape[0]) - logits, value = eqx.filter_vmap(params)(last_obs) - pi = distrax.Categorical(logits=logits) - action = pi.sample(seed=_rng) - log_prob = pi.log_prob(action) - # pi = eqx.filter_vmap(distributions.Categorical)(logits) - # action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) - # action = action.astype('int32') - # log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, action) + # logits, value = eqx.filter_vmap(params)(last_obs) + # pi = distrax.Categorical(logits=logits) + # action = pi.sample(seed=_rng) + # log_prob = pi.log_prob(action) + pi = eqx.filter_vmap(distributions.Categorical)(logits) + action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) + action = action.astype('int32') + log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, action) # STEP ENV rng, _rng = jax.random.split(rng) @@ -258,10 +204,10 @@ def _update_minbatch(tup, batch_info): def _loss_fn(params, traj_batch, gae, targets): # RERUN NETWORK logits, value = eqx.filter_vmap(params)(traj_batch.obs) - pi = distrax.Categorical(logits=logits) - log_prob = pi.log_prob(traj_batch.action) - # pi = eqx.filter_vmap(distributions.Categorical)(logits) - # log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, traj_batch.action) + # pi = distrax.Categorical(logits=logits) + # log_prob = pi.log_prob(traj_batch.action) + pi = eqx.filter_vmap(distributions.Categorical)(logits) + log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, traj_batch.action) # CALCULATE VALUE LOSS value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( @@ -340,14 +286,14 @@ def loop_fn(tup): state, R, rng_key = tup logits, value = eqx.filter_vmap(params)(state.observation) # action = logits.argmax(axis=-1) - pi = distrax.Categorical(logits=logits) - rng_key, _rng = jax.random.split(rng_key) - action = pi.sample(seed=_rng) - # pi = eqx.filter_vmap(distributions.Categorical)(logits) + # pi = distrax.Categorical(logits=logits) # rng_key, _rng = jax.random.split(rng_key) - # __rng = jax.random.split(_rng, state.observation.shape[0]) - # action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) - # action = action.astype('int32') + # action = pi.sample(seed=_rng) + pi = eqx.filter_vmap(distributions.Categorical)(logits) + rng_key, _rng = jax.random.split(rng_key) + __rng = jax.random.split(_rng, state.observation.shape[0]) + action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) + action = action.astype('int32') rng_key, _rng = jax.random.split(rng_key) keys = jax.random.split(_rng, state.observation.shape[0]) state = step_fn(state, action, keys) From 839730613d7755c121a52c9af27a27bd028558fc Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 5 Feb 2025 21:10:50 -0800 Subject: [PATCH 08/17] fix --- examples/minatar-ppo/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index b27bd74f8..8d79f18ae 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -144,7 +144,7 @@ def _env_step(runner_state, unused): # SELECT ACTION rng, _rng = jax.random.split(rng) __rng = jax.random.split(_rng, last_obs.shape[0]) - # logits, value = eqx.filter_vmap(params)(last_obs) + logits, value = eqx.filter_vmap(params)(last_obs) # pi = distrax.Categorical(logits=logits) # action = pi.sample(seed=_rng) # log_prob = pi.log_prob(action) From 2bdc6ccd32bc3cdbb006aacf013ba8ad89518921 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 5 Feb 2025 22:06:03 -0800 Subject: [PATCH 09/17] log --- examples/minatar-ppo/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index 8d79f18ae..f7d7c3ace 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -46,6 +46,8 @@ class PPOConfig(BaseModel): max_grad_norm: float = 0.5 wandb_project: str = "pgx-minatar-ppo" save_model: bool = False + equinox: bool = True + distrax: bool = False class Config: extra = "forbid" From 730a8eced84a9570eb7a7f1eedd66c9e4e7b5adb Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 5 Feb 2025 22:11:13 -0800 Subject: [PATCH 10/17] init --- examples/minatar-ppo/train.py | 45 ++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index f7d7c3ace..ba84121b3 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -48,6 +48,7 @@ class PPOConfig(BaseModel): save_model: bool = False equinox: bool = True distrax: bool = False + changed_init_equinox: bool = True class Config: extra = "forbid" @@ -61,6 +62,32 @@ class Config: num_updates = args.total_timesteps // args.num_envs // args.num_steps num_minibatches = args.num_envs * args.num_steps // args.minibatch_size +def init_weight(layer, key): + def where(m): + return m.weight + + s = layer.weight.shape + if len(s) == 2: + f = s[1] + else: + f = s[1] * s[2] * s[3] + return eqx.tree_at(where, layer, (1.0 / jnp.sqrt(f)) * jax.random.truncated_normal(key, -2.0, 2.0, s)) + + +def init_bias(layer): + def where(m): + return m.bias + + if layer.bias is not None: + return eqx.tree_at(where, layer, jnp.zeros_like(layer.bias)) + return layer + + +def truncated_normal_init(layer, key): + layer = init_weight(layer, key) + layer = init_bias(layer) + return layer + class ActorCritic(eqx.Module): features: list @@ -75,33 +102,33 @@ def __init__(self, num_actions, key, activation="tanh"): act_fn = jax.nn.tanh keys = jax.random.split(key, 8) - + self.features = [ - eqx.nn.Conv2d(env.observation_shape[2], 32, 2, padding="SAME", key=keys[0]), + truncated_normal_init(eqx.nn.Conv2d(env.observation_shape[2], 32, 2, padding="SAME", key=keys[0]), keys[0]), # (4, 10, 10) -> (32, 10, 10) jax.nn.relu, lambda x: jnp.moveaxis(x, 0, -1), eqx.nn.AvgPool2d(2, 2), # (10, 10, 32) -> (10, 5, 16) lambda x: x.flatten(), - eqx.nn.Linear(10 * 5 * 16, 64, key=keys[1]), + truncated_normal_init(eqx.nn.Linear(10 * 5 * 16, 64, key=keys[1]), key=keys[1]), jax.nn.relu, ] self.actor = [ - eqx.nn.Linear(64, 64, key=keys[2]), + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[2]), keys[2]), act_fn, - eqx.nn.Linear(64, 64, key=keys[3]), + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[3]), keys[3]), act_fn, - eqx.nn.Linear(64, num_actions, key=keys[4]), + truncated_normal_init(eqx.nn.Linear(64, num_actions, key=keys[4]), keys[4]), ] self.critic = [ - eqx.nn.Linear(64, 64, key=keys[5]), + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[5]), keys[5]), act_fn, - eqx.nn.Linear(64, 64, key=keys[6]), + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[6]), keys[6]), act_fn, - eqx.nn.Linear(64, 1, key=keys[7]), + truncated_normal_init(eqx.nn.Linear(64, 1, key=keys[7]), keys[7]), ] def __call__(self, x): From bfcdabf2ad5ddad61c57f6f466e986a781777429 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 5 Feb 2025 23:08:19 -0800 Subject: [PATCH 11/17] fix --- examples/alphazero/network.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/alphazero/network.py b/examples/alphazero/network.py index 67d68909d..391deaf96 100644 --- a/examples/alphazero/network.py +++ b/examples/alphazero/network.py @@ -53,8 +53,7 @@ def __call__(self, x, state): return x + i, state -class AZNet(eqx.Module): - """AlphaZero NN architecture.""" +class AZNete(eqx.Module): init_layers: list resnet: list @@ -74,7 +73,7 @@ def __init__( resnet_cls = BlockV2 if resnet_v2 else BlockV1 keys = jax.random.split(key, num_blocks + 5) - self.init_layers = [eqx.nn.Conv2d(input_channels, output_channels, kernel_size=3, key=keys[0])] + self.init_layers = [eqx.nn.Conv2d(input_channels, output_channels, kernel_size=3, padding="SAME", key=keys[0])] if not resnet_v2: self.init_layers += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] self.resnet = [resnet_cls(output_channels, output_channels, keys[i + 1]) for i in range(num_blocks)] @@ -82,20 +81,20 @@ def __init__( if resnet_v2: self.post_resnet += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] self.policy_head = [ - eqx.nn.Conv2d(output_channels, 2, kernel_size=1, key=keys[num_blocks + 1]), + eqx.nn.Conv2d(output_channels, 2, kernel_size=1, padding="SAME", key=keys[num_blocks + 1]), eqx.nn.BatchNorm(2, "batch", momentum=0.9), jax.nn.relu, lambda x: x.flatten(), - # TODO: infer 98 from inputs - eqx.nn.Linear(98, num_actions, key=keys[num_blocks + 2]), + # TODO: infer from inputs + eqx.nn.Linear(162, num_actions, key=keys[num_blocks + 2]), ] self.value_head = [ - eqx.nn.Conv2d(output_channels, 1, kernel_size=1, key=keys[num_blocks + 3]), + eqx.nn.Conv2d(output_channels, 1, kernel_size=1, padding="SAME", key=keys[num_blocks + 3]), eqx.nn.BatchNorm(1, "batch", momentum=0.9), jax.nn.relu, lambda x: x.flatten(), - eqx.nn.Linear(49, output_channels, key=keys[num_blocks + 2]), + eqx.nn.Linear(81, output_channels, key=keys[num_blocks + 2]), jax.nn.relu, eqx.nn.Linear(output_channels, 1, key=keys[num_blocks + 2]), jnp.tanh, From 909407a1059fdd87698342eba94ce7bd12759e7e Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 5 Feb 2025 23:10:51 -0800 Subject: [PATCH 12/17] wandb fix --- examples/alphazero/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/alphazero/train.py b/examples/alphazero/train.py index 57477bad8..9e071c510 100644 --- a/examples/alphazero/train.py +++ b/examples/alphazero/train.py @@ -58,6 +58,7 @@ class Config(BaseModel): learning_rate: float = 0.001 # eval params eval_interval: int = 5 + wandb_project: str = "pgx-az" class Config: extra = "forbid" @@ -255,7 +256,7 @@ def body_fn(val): if __name__ == "__main__": - wandb.init(project="pgx-az", config=config.model_dump()) + wandb.init(project=config.wandb_project, config=config.model_dump()) # Initialize model and opt_state dummy_state = jax.vmap(env.init)(jax.random.split(jax.random.PRNGKey(0), 2)) From 88b2626239811d09dcf3680112b5d449011063eb Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 5 Feb 2025 23:11:27 -0800 Subject: [PATCH 13/17] e --- examples/alphazero/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/alphazero/network.py b/examples/alphazero/network.py index 391deaf96..6e31787c2 100644 --- a/examples/alphazero/network.py +++ b/examples/alphazero/network.py @@ -53,7 +53,7 @@ def __call__(self, x, state): return x + i, state -class AZNete(eqx.Module): +class AZNet(eqx.Module): init_layers: list resnet: list From e8525f2fe9612ad4de209fb859fc57233cf60a52 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Fri, 7 Feb 2025 14:57:05 -0800 Subject: [PATCH 14/17] mode --- examples/alphazero/network.py | 16 ++++++++-------- examples/alphazero/train.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/alphazero/network.py b/examples/alphazero/network.py index 6e31787c2..19c55eae6 100644 --- a/examples/alphazero/network.py +++ b/examples/alphazero/network.py @@ -16,8 +16,8 @@ def __init__(self, in_channels, out_channels, key): keys = jax.random.split(key, 2) self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, padding="SAME", kernel_size=3, key=keys[0]) self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, padding="SAME", kernel_size=3, key=keys[1]) - self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) - self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) + self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9, mode="batch") + self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9, mode="batch") def __call__(self, x, state): i = x @@ -39,8 +39,8 @@ def __init__(self, in_channels, out_channels, key): keys = jax.random.split(key, 2) self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, padding="SAME", kernel_size=3, key=keys[0]) self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, padding="SAME", kernel_size=3, key=keys[1]) - self.norm1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9) - self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9) + self.norm1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9, mode="batch") + self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9, mode="batch") def __call__(self, x, state): i = x @@ -75,14 +75,14 @@ def __init__( keys = jax.random.split(key, num_blocks + 5) self.init_layers = [eqx.nn.Conv2d(input_channels, output_channels, kernel_size=3, padding="SAME", key=keys[0])] if not resnet_v2: - self.init_layers += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] + self.init_layers += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9, mode="batch"), jax.nn.relu] self.resnet = [resnet_cls(output_channels, output_channels, keys[i + 1]) for i in range(num_blocks)] self.post_resnet = [] if resnet_v2: - self.post_resnet += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9), jax.nn.relu] + self.post_resnet += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9, mode="batch"), jax.nn.relu] self.policy_head = [ eqx.nn.Conv2d(output_channels, 2, kernel_size=1, padding="SAME", key=keys[num_blocks + 1]), - eqx.nn.BatchNorm(2, "batch", momentum=0.9), + eqx.nn.BatchNorm(2, "batch", momentum=0.9, mode="batch"), jax.nn.relu, lambda x: x.flatten(), # TODO: infer from inputs @@ -91,7 +91,7 @@ def __init__( self.value_head = [ eqx.nn.Conv2d(output_channels, 1, kernel_size=1, padding="SAME", key=keys[num_blocks + 3]), - eqx.nn.BatchNorm(1, "batch", momentum=0.9), + eqx.nn.BatchNorm(1, "batch", momentum=0.9, mode="batch"), jax.nn.relu, lambda x: x.flatten(), eqx.nn.Linear(81, output_channels, key=keys[num_blocks + 2]), diff --git a/examples/alphazero/train.py b/examples/alphazero/train.py index 9e071c510..b08d4351a 100644 --- a/examples/alphazero/train.py +++ b/examples/alphazero/train.py @@ -57,7 +57,7 @@ class Config(BaseModel): training_batch_size: int = 4096 learning_rate: float = 0.001 # eval params - eval_interval: int = 5 + eval_interval: int = 10 wandb_project: str = "pgx-az" class Config: From bf73dc777b7eac1c48e7edb238181737a528bbf4 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Sun, 9 Feb 2025 21:24:54 -0800 Subject: [PATCH 15/17] format + requirements --- examples/alphazero/requirements.txt | 3 ++- examples/alphazero/train.py | 17 +++-------------- examples/minatar-ppo/train.py | 10 ++++------ 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/examples/alphazero/requirements.txt b/examples/alphazero/requirements.txt index 7670b4fa6..6d93e85d7 100644 --- a/examples/alphazero/requirements.txt +++ b/examples/alphazero/requirements.txt @@ -1,7 +1,8 @@ pgx>=2.0.0 -dm-haiku +equinox mctx optax wandb omegaconf pydantic +cloundpickle \ No newline at end of file diff --git a/examples/alphazero/train.py b/examples/alphazero/train.py index b08d4351a..c14e29a62 100644 --- a/examples/alphazero/train.py +++ b/examples/alphazero/train.py @@ -12,14 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -# import os - -# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(4) - - import datetime import os -import pickle +import cloudpickle as pickle import time from functools import partial from typing import NamedTuple @@ -271,13 +266,7 @@ def body_fn(val): ) opt_state = optimizer.init(eqx.filter(init_model, eqx.is_array)) # replicates to all devices - arr, static = eqx.partition((init_model, opt_state, state), eqx.is_array) - # arr_rep = jax.device_put_replicated(arr, devices) - arr_rep = arr - train_model = eqx.combine(arr_rep[0], static[0]) - opt_state = eqx.combine(arr_rep[1], static[1]) - state = eqx.combine(arr_rep[2], static[2]) - model = (train_model, state) + model = (init_model, state) # Prepare checkpoint dir now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=9))) @@ -323,7 +312,7 @@ def body_fn(val): "env_id": env.id, "env_version": env.version, } - # pickle.dump(dic, f) + pickle.dump(dic, f) print(log) wandb.log(log) diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index ba84121b3..030182b0b 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -46,9 +46,6 @@ class PPOConfig(BaseModel): max_grad_norm: float = 0.5 wandb_project: str = "pgx-minatar-ppo" save_model: bool = False - equinox: bool = True - distrax: bool = False - changed_init_equinox: bool = True class Config: extra = "forbid" @@ -62,6 +59,7 @@ class Config: num_updates = args.total_timesteps // args.num_envs // args.num_steps num_minibatches = args.num_envs * args.num_steps // args.minibatch_size + def init_weight(layer, key): def where(m): return m.weight @@ -102,7 +100,7 @@ def __init__(self, num_actions, key, activation="tanh"): act_fn = jax.nn.tanh keys = jax.random.split(key, 8) - + self.features = [ truncated_normal_init(eqx.nn.Conv2d(env.observation_shape[2], 32, 2, padding="SAME", key=keys[0]), keys[0]), # (4, 10, 10) -> (32, 10, 10) @@ -179,7 +177,7 @@ def _env_step(runner_state, unused): # log_prob = pi.log_prob(action) pi = eqx.filter_vmap(distributions.Categorical)(logits) action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) - action = action.astype('int32') + action = action.astype("int32") log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, action) # STEP ENV @@ -322,7 +320,7 @@ def loop_fn(tup): rng_key, _rng = jax.random.split(rng_key) __rng = jax.random.split(_rng, state.observation.shape[0]) action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) - action = action.astype('int32') + action = action.astype("int32") rng_key, _rng = jax.random.split(rng_key) keys = jax.random.split(_rng, state.observation.shape[0]) state = step_fn(state, action, keys) From 84f001f6040b264a7292554acde157abc709cf04 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Sun, 9 Feb 2025 21:26:21 -0800 Subject: [PATCH 16/17] setting --- .vscode/settings.json | 2 +- requirements/requirements-dev.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 032139fa8..5b58514fe 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -29,4 +29,4 @@ "python.terminal.activateEnvInCurrentTerminal": true, "python.linting.enabled": false, "python.linting.pylintEnabled": true -} +} \ No newline at end of file diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index c67fa62f5..871587745 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -3,6 +3,7 @@ pytest pytest-xdist matplotlib ipython +dm-haiku==0.0.10 equinox pytest-cov pgx-minatar From b403e0a6791e19928923772df0caa5ee68ebaf56 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Sun, 9 Feb 2025 21:27:03 -0800 Subject: [PATCH 17/17] settings --- .vscode/settings.json | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 5b58514fe..40b5d0774 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -24,9 +24,7 @@ "--config", "pyproject.toml", "--ignore", "E203,E501,W503" ], - "python.linting.flake8Enabled": false, + "python.linting.flake8Enabled": true, "python.defaultInterpreterPath": "${workspace}/venv/bin/python3", - "python.terminal.activateEnvInCurrentTerminal": true, - "python.linting.enabled": false, - "python.linting.pylintEnabled": true + "python.terminal.activateEnvInCurrentTerminal": true } \ No newline at end of file