diff --git a/.vscode/settings.json b/.vscode/settings.json index 14ed2a311..40b5d0774 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -27,4 +27,4 @@ "python.linting.flake8Enabled": true, "python.defaultInterpreterPath": "${workspace}/venv/bin/python3", "python.terminal.activateEnvInCurrentTerminal": true -} +} \ No newline at end of file diff --git a/examples/alphazero/network.py b/examples/alphazero/network.py index e6ca93004..19c55eae6 100644 --- a/examples/alphazero/network.py +++ b/examples/alphazero/network.py @@ -1,93 +1,137 @@ # We referred to Haiku's ResNet implementation: # https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/nets/resnet.py -import haiku as hk +import equinox as eqx import jax import jax.numpy as jnp -class BlockV1(hk.Module): - def __init__(self, num_channels, name="BlockV1"): - super(BlockV1, self).__init__(name=name) - self.num_channels = num_channels +class BlockV1(eqx.Module): + conv1: eqx.nn.Conv2d + conv2: eqx.nn.Conv2d + norm1: eqx.nn.BatchNorm + norm2: eqx.nn.BatchNorm - def __call__(self, x, is_training, test_local_stats): + def __init__(self, in_channels, out_channels, key): + keys = jax.random.split(key, 2) + self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, padding="SAME", kernel_size=3, key=keys[0]) + self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, padding="SAME", kernel_size=3, key=keys[1]) + self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9, mode="batch") + self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9, mode="batch") + + def __call__(self, x, state): i = x - x = hk.Conv2D(self.num_channels, kernel_shape=3)(x) - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) + x = self.conv1(x) + x, state = self.norm1(x, state) x = jax.nn.relu(x) - x = hk.Conv2D(self.num_channels, kernel_shape=3)(x) - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) - return jax.nn.relu(x + i) + x = self.conv2(x) + x, state = self.norm2(x, state) + 
return jax.nn.relu(x + i), state + +class BlockV2(eqx.Module): + conv1: eqx.nn.Conv2d + conv2: eqx.nn.Conv2d + norm1: eqx.nn.BatchNorm + norm2: eqx.nn.BatchNorm -class BlockV2(hk.Module): - def __init__(self, num_channels, name="BlockV2"): - super(BlockV2, self).__init__(name=name) - self.num_channels = num_channels + def __init__(self, in_channels, out_channels, key): + keys = jax.random.split(key, 2) + self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, padding="SAME", kernel_size=3, key=keys[0]) + self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, padding="SAME", kernel_size=3, key=keys[1]) + self.norm1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9, mode="batch") + self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9, mode="batch") - def __call__(self, x, is_training, test_local_stats): + def __call__(self, x, state): i = x - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) + x, state = self.norm1(x, state) x = jax.nn.relu(x) - x = hk.Conv2D(self.num_channels, kernel_shape=3)(x) - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) + x = self.conv1(x) + x, state = self.norm2(x, state) x = jax.nn.relu(x) - x = hk.Conv2D(self.num_channels, kernel_shape=3)(x) - return x + i + x = self.conv2(x) + return x + i, state + +class AZNet(eqx.Module): -class AZNet(hk.Module): - """AlphaZero NN architecture.""" + init_layers: list + resnet: list + post_resnet: list + policy_head: list + value_head: list def __init__( self, num_actions, - num_channels: int = 64, + input_channels, + key, + output_channels: int = 64, num_blocks: int = 5, resnet_v2: bool = True, - name="az_net", ): - super().__init__(name=name) - self.num_actions = num_actions - self.num_channels = num_channels - self.num_blocks = num_blocks - self.resnet_v2 = resnet_v2 - self.resnet_cls = BlockV2 if resnet_v2 else BlockV1 - - def __call__(self, x, is_training, test_local_stats): + resnet_cls = BlockV2 if resnet_v2 else BlockV1 + + keys = 
jax.random.split(key, num_blocks + 5) + self.init_layers = [eqx.nn.Conv2d(input_channels, output_channels, kernel_size=3, padding="SAME", key=keys[0])] + if not resnet_v2: + self.init_layers += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9, mode="batch"), jax.nn.relu] + self.resnet = [resnet_cls(output_channels, output_channels, keys[i + 1]) for i in range(num_blocks)] + self.post_resnet = [] + if resnet_v2: + self.post_resnet += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9, mode="batch"), jax.nn.relu] + self.policy_head = [ + eqx.nn.Conv2d(output_channels, 2, kernel_size=1, padding="SAME", key=keys[num_blocks + 1]), + eqx.nn.BatchNorm(2, "batch", momentum=0.9, mode="batch"), + jax.nn.relu, + lambda x: x.flatten(), + # TODO: infer from inputs + eqx.nn.Linear(162, num_actions, key=keys[num_blocks + 2]), + ] + + self.value_head = [ + eqx.nn.Conv2d(output_channels, 1, kernel_size=1, padding="SAME", key=keys[num_blocks + 3]), + eqx.nn.BatchNorm(1, "batch", momentum=0.9, mode="batch"), + jax.nn.relu, + lambda x: x.flatten(), + eqx.nn.Linear(81, output_channels, key=keys[num_blocks + 2]), + jax.nn.relu, + eqx.nn.Linear(output_channels, 1, key=keys[num_blocks + 2]), + jnp.tanh, + jnp.squeeze, + ] + + def __call__(self, x, state): x = x.astype(jnp.float32) - x = hk.Conv2D(self.num_channels, kernel_shape=3)(x) - - if not self.resnet_v2: - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) - x = jax.nn.relu(x) - - for i in range(self.num_blocks): - x = self.resnet_cls(self.num_channels, name=f"block_{i}")( - x, is_training, test_local_stats - ) - - if self.resnet_v2: - x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats) - x = jax.nn.relu(x) - - # policy head - logits = hk.Conv2D(output_channels=2, kernel_shape=1)(x) - logits = hk.BatchNorm(True, True, 0.9)(logits, is_training, test_local_stats) - logits = jax.nn.relu(logits) - logits = hk.Flatten()(logits) - logits = hk.Linear(self.num_actions)(logits) - - # value head 
- v = hk.Conv2D(output_channels=1, kernel_shape=1)(x) - v = hk.BatchNorm(True, True, 0.9)(v, is_training, test_local_stats) - v = jax.nn.relu(v) - v = hk.Flatten()(v) - v = hk.Linear(self.num_channels)(v) - v = jax.nn.relu(v) - v = hk.Linear(1)(v) - v = jnp.tanh(v) - v = v.reshape((-1,)) - - return logits, v + x = jnp.moveaxis(x, -1, 0) + + for layer in self.init_layers: + if isinstance(layer, eqx.nn.StatefulLayer): + x, state = layer(x, state) + else: + x = layer(x) + + for layer in self.resnet: + x, state = layer(x, state) + + for layer in self.post_resnet: + if isinstance(layer, eqx.nn.StatefulLayer): + x, state = layer(x, state) + else: + x = layer(x) + + logits = x.copy() + for layer in self.policy_head: + if isinstance(layer, eqx.nn.StatefulLayer): + logits, state = layer(logits, state) + else: + logits = layer(logits) + + v = x.copy() + for layer in self.value_head: + if isinstance(layer, eqx.nn.StatefulLayer): + v, state = layer(v, state) + else: + v = layer(v) + + return (logits, v), state diff --git a/examples/alphazero/requirements.txt b/examples/alphazero/requirements.txt index 7670b4fa6..6d93e85d7 100644 --- a/examples/alphazero/requirements.txt +++ b/examples/alphazero/requirements.txt @@ -1,7 +1,8 @@ pgx>=2.0.0 -dm-haiku +equinox mctx optax wandb omegaconf pydantic +cloudpickle \ No newline at end of file diff --git a/examples/alphazero/train.py b/examples/alphazero/train.py index a80e5d9ec..c14e29a62 100644 --- a/examples/alphazero/train.py +++ b/examples/alphazero/train.py @@ -14,12 +14,11 @@ import datetime import os -import pickle +import cloudpickle as pickle import time from functools import partial from typing import NamedTuple -import haiku as hk import jax import jax.numpy as jnp import mctx @@ -29,6 +28,7 @@ from omegaconf import OmegaConf from pgx.experimental import auto_reset from pydantic import BaseModel +import equinox as eqx from network import AZNet @@ -52,7 +52,8 @@ class Config(BaseModel): training_batch_size: int = 4096 
learning_rate: float = 0.001 # eval params - eval_interval: int = 5 + eval_interval: int = 10 + wandb_project: str = "pgx-az" class Config: extra = "forbid" @@ -65,50 +66,9 @@ class Config: env = pgx.make(config.env_id) baseline = pgx.make_baseline_model(config.env_id + "_v0") - -def forward_fn(x, is_eval=False): - net = AZNet( - num_actions=env.num_actions, - num_channels=config.num_channels, - num_blocks=config.num_layers, - resnet_v2=config.resnet_v2, - ) - policy_out, value_out = net(x, is_training=not is_eval, test_local_stats=False) - return policy_out, value_out - - -forward = hk.without_apply_rng(hk.transform_with_state(forward_fn)) optimizer = optax.adam(learning_rate=config.learning_rate) -def recurrent_fn(model, rng_key: jnp.ndarray, action: jnp.ndarray, state: pgx.State): - # model: params - # state: embedding - del rng_key - model_params, model_state = model - - current_player = state.current_player - state = jax.vmap(env.step)(state, action) - - (logits, value), _ = forward.apply(model_params, model_state, state.observation, is_eval=True) - # mask invalid actions - logits = logits - jnp.max(logits, axis=-1, keepdims=True) - logits = jnp.where(state.legal_action_mask, logits, jnp.finfo(logits.dtype).min) - - reward = state.rewards[jnp.arange(state.rewards.shape[0]), current_player] - value = jnp.where(state.terminated, 0.0, value) - discount = -1.0 * jnp.ones_like(value) - discount = jnp.where(state.terminated, 0.0, discount) - - recurrent_fn_output = mctx.RecurrentFnOutput( - reward=reward, - discount=discount, - prior_logits=logits, - value=value, - ) - return recurrent_fn_output, state - - class SelfplayOutput(NamedTuple): obs: jnp.ndarray reward: jnp.ndarray @@ -117,22 +77,55 @@ class SelfplayOutput(NamedTuple): discount: jnp.ndarray -@jax.pmap +@partial(eqx.filter_pmap, in_axes=(None, 0)) def selfplay(model, rng_key: jnp.ndarray) -> SelfplayOutput: model_params, model_state = model + model_params = eqx.nn.inference_mode(model_params) + model = 
(model_params, model_state) + arr, static = eqx.partition(model, eqx.is_array) + + def recurrent_fn(model, rng_key: jnp.ndarray, action: jnp.ndarray, state: pgx.State): + del rng_key + model = eqx.combine(model, static) + model_params, model_state = model + + current_player = state.current_player + state = jax.vmap(env.step)(state, action) + + # (logits, value), _ = forward.apply(model_params, model_state, state.observation, is_eval=True) + (logits, value), _ = eqx.filter_vmap(model_params, in_axes=(0, None), out_axes=(0, None), axis_name="batch")( + state.observation, model_state + ) + # mask invalid actions + logits = logits - jnp.max(logits, axis=-1, keepdims=True) + logits = jnp.where(state.legal_action_mask, logits, jnp.finfo(logits.dtype).min) + + reward = state.rewards[jnp.arange(state.rewards.shape[0]), current_player] + value = jnp.where(state.terminated, 0.0, value) + discount = -1.0 * jnp.ones_like(value) + discount = jnp.where(state.terminated, 0.0, discount) + + recurrent_fn_output = mctx.RecurrentFnOutput( + reward=reward, + discount=discount, + prior_logits=logits, + value=value, + ) + return recurrent_fn_output, state + batch_size = config.selfplay_batch_size // num_devices def step_fn(state, key) -> SelfplayOutput: key1, key2 = jax.random.split(key) observation = state.observation - (logits, value), _ = forward.apply( - model_params, model_state, state.observation, is_eval=True + (logits, value), _ = eqx.filter_vmap(model_params, in_axes=(0, None), out_axes=(0, None), axis_name="batch")( + state.observation, model_state ) root = mctx.RootFnOutput(prior_logits=logits, value=value, embedding=state) policy_output = mctx.gumbel_muzero_policy( - params=model, + params=arr, rng_key=key1, root=root, recurrent_fn=recurrent_fn, @@ -200,9 +193,9 @@ def body_fn(carry, i): def loss_fn(model_params, model_state, samples: Sample): - (logits, value), model_state = forward.apply( - model_params, model_state, samples.obs, is_eval=False - ) + (logits, value), 
model_state = eqx.filter_vmap( + model_params, in_axes=(0, None), out_axes=(0, None), axis_name="batch" + )(samples.obs, model_state) policy_loss = optax.softmax_cross_entropy(logits, samples.policy_tgt) policy_loss = jnp.mean(policy_loss) @@ -213,25 +206,26 @@ def loss_fn(model_params, model_state, samples: Sample): return policy_loss + value_loss, (model_state, policy_loss, value_loss) -@partial(jax.pmap, axis_name="i") +@partial(eqx.filter_pmap, axis_name="i", in_axes=(None, None, 0), out_axes=(None, None, 0, 0)) def train(model, opt_state, data: Sample): model_params, model_state = model - grads, (model_state, policy_loss, value_loss) = jax.grad(loss_fn, has_aux=True)( + grads, (model_state, policy_loss, value_loss) = eqx.filter_grad(loss_fn, has_aux=True)( model_params, model_state, data ) grads = jax.lax.pmean(grads, axis_name="i") updates, opt_state = optimizer.update(grads, opt_state) - model_params = optax.apply_updates(model_params, updates) + model_params = eqx.apply_updates(model_params, updates) model = (model_params, model_state) return model, opt_state, policy_loss, value_loss -@jax.pmap +@partial(eqx.filter_pmap, in_axes=(0, None)) def evaluate(rng_key, my_model): - """A simplified evaluation by sampling. Only for debugging. + """A simplified evaluation by sampling. Only for debugging. 
Please use MCTS and run tournaments for serious evaluation.""" my_player = 0 - my_model_params, my_model_state = my_model + my_model, my_model_state = my_model + inference_model = eqx.nn.inference_mode(my_model) key, subkey = jax.random.split(rng_key) batch_size = config.selfplay_batch_size // num_devices @@ -240,8 +234,8 @@ def evaluate(rng_key, my_model): def body_fn(val): key, state, R = val - (my_logits, _), _ = forward.apply( - my_model_params, my_model_state, state.observation, is_eval=True + (my_logits, _), _ = eqx.filter_vmap(inference_model, in_axes=(0, None), out_axes=(0, None), axis_name="batch")( + state.observation, my_model_state ) opp_logits, _ = baseline(state.observation) is_my_turn = (state.current_player == my_player).reshape((-1, 1)) @@ -252,22 +246,27 @@ def body_fn(val): R = R + state.rewards[jnp.arange(batch_size), my_player] return (key, state, R) - _, _, R = jax.lax.while_loop( - lambda x: ~(x[1].terminated.all()), body_fn, (key, state, jnp.zeros(batch_size)) - ) + _, _, R = jax.lax.while_loop(lambda x: ~(x[1].terminated.all()), body_fn, (key, state, jnp.zeros(batch_size))) return R if __name__ == "__main__": - wandb.init(project="pgx-az", config=config.model_dump()) + wandb.init(project=config.wandb_project, config=config.model_dump()) # Initialize model and opt_state dummy_state = jax.vmap(env.init)(jax.random.split(jax.random.PRNGKey(0), 2)) dummy_input = dummy_state.observation - model = forward.init(jax.random.PRNGKey(0), dummy_input) # (params, state) - opt_state = optimizer.init(params=model[0]) + init_model, state = eqx.nn.make_with_state(AZNet)( + env.num_actions, + env.observation_shape[-1], + jax.random.key(config.seed), + config.num_channels, + config.num_layers, + config.resnet_v2, + ) + opt_state = optimizer.init(eqx.filter(init_model, eqx.is_array)) # replicates to all devices - model, opt_state = jax.device_put_replicated((model, opt_state), devices) + model = (init_model, state) # Prepare checkpoint dir now = 
datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=9))) @@ -281,7 +280,7 @@ def body_fn(val): frames: int = 0 log = {"iteration": iteration, "hours": hours, "frames": frames} - rng_key = jax.random.PRNGKey(config.seed) + rng_key = jax.random.key(config.seed) while True: if iteration % config.eval_interval == 0: # Evaluation @@ -298,7 +297,8 @@ def body_fn(val): ) # Store checkpoints - model_0, opt_state_0 = jax.tree_util.tree_map(lambda x: x[0], (model, opt_state)) + # model_0, opt_state_0 = jax.tree_util.tree_map(lambda x: x[0], (train_model, opt_state)) + model_0, opt_state_0 = eqx.filter((model[0], opt_state), eqx.is_array) with open(os.path.join(ckpt_dir, f"{iteration:06d}.ckpt"), "wb") as f: dic = { "config": config, @@ -338,9 +338,7 @@ def body_fn(val): ixs = jax.random.permutation(subkey, jnp.arange(samples.obs.shape[0])) samples = jax.tree_util.tree_map(lambda x: x[ixs], samples) # shuffle num_updates = samples.obs.shape[0] // config.training_batch_size - minibatches = jax.tree_util.tree_map( - lambda x: x.reshape((num_updates, num_devices, -1) + x.shape[1:]), samples - ) + minibatches = jax.tree_util.tree_map(lambda x: x.reshape((num_updates, num_devices, -1) + x.shape[1:]), samples) # Training policy_losses, value_losses = [], [] @@ -349,6 +347,7 @@ def body_fn(val): model, opt_state, policy_loss, value_loss = train(model, opt_state, minibatch) policy_losses.append(policy_loss.mean().item()) value_losses.append(value_loss.mean().item()) + policy_loss = sum(policy_losses) / len(policy_losses) value_loss = sum(value_losses) / len(value_losses) diff --git a/examples/minatar-ppo/requirements.txt b/examples/minatar-ppo/requirements.txt index af303b4c6..2020f983b 100644 --- a/examples/minatar-ppo/requirements.txt +++ b/examples/minatar-ppo/requirements.txt @@ -1,7 +1,7 @@ pgx>=2.0.0 -dm-haiku +equinox optax -distrax +distreqx omegaconf pydantic wandb diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py index 
173518c09..030182b0b 100644 --- a/examples/minatar-ppo/train.py +++ b/examples/minatar-ppo/train.py @@ -7,9 +7,10 @@ import sys import jax import jax.numpy as jnp -import haiku as hk +import equinox as eqx import optax from typing import NamedTuple, Literal +from distreqx import distributions import distrax import pgx from pgx.experimental import auto_reset @@ -59,52 +60,91 @@ class Config: num_minibatches = args.num_envs * args.num_steps // args.minibatch_size -class ActorCritic(hk.Module): - def __init__(self, num_actions, activation="tanh"): - super().__init__() - self.num_actions = num_actions - self.activation = activation - assert activation in ["relu", "tanh"] +def init_weight(layer, key): + def where(m): + return m.weight - def __call__(self, x): - x = x.astype(jnp.float32) - if self.activation == "relu": - activation = jax.nn.relu - else: - activation = jax.nn.tanh - x = hk.Conv2D(32, kernel_shape=2)(x) - x = jax.nn.relu(x) - x = hk.avg_pool(x, window_shape=(2, 2), - strides=(2, 2), padding="VALID") - x = x.reshape((x.shape[0], -1)) # flatten - x = hk.Linear(64)(x) - x = jax.nn.relu(x) - actor_mean = hk.Linear(64)(x) - actor_mean = activation(actor_mean) - actor_mean = hk.Linear(64)(actor_mean) - actor_mean = activation(actor_mean) - actor_mean = hk.Linear(self.num_actions)(actor_mean) - - critic = hk.Linear(64)(x) - critic = activation(critic) - critic = hk.Linear(64)(critic) - critic = activation(critic) - critic = hk.Linear(1)(critic) + s = layer.weight.shape + if len(s) == 2: + f = s[1] + else: + f = s[1] * s[2] * s[3] + return eqx.tree_at(where, layer, (1.0 / jnp.sqrt(f)) * jax.random.truncated_normal(key, -2.0, 2.0, s)) - return actor_mean, jnp.squeeze(critic, axis=-1) + +def init_bias(layer): + def where(m): + return m.bias + + if layer.bias is not None: + return eqx.tree_at(where, layer, jnp.zeros_like(layer.bias)) + return layer -def forward_fn(x, is_eval=False): - net = ActorCritic(env.num_actions, activation="tanh") - logits, value = net(x) - 
return logits, value +def truncated_normal_init(layer, key): + layer = init_weight(layer, key) + layer = init_bias(layer) + return layer -forward = hk.without_apply_rng(hk.transform(forward_fn)) +class ActorCritic(eqx.Module): + features: list + actor: list + critic: list + def __init__(self, num_actions, key, activation="tanh"): + assert activation in ["relu", "tanh"] + if activation == "relu": + act_fn = jax.nn.relu + else: + act_fn = jax.nn.tanh + + keys = jax.random.split(key, 8) + + self.features = [ + truncated_normal_init(eqx.nn.Conv2d(env.observation_shape[2], 32, 2, padding="SAME", key=keys[0]), keys[0]), + # (4, 10, 10) -> (32, 10, 10) + jax.nn.relu, + lambda x: jnp.moveaxis(x, 0, -1), + eqx.nn.AvgPool2d(2, 2), + # (10, 10, 32) -> (10, 5, 16) + lambda x: x.flatten(), + truncated_normal_init(eqx.nn.Linear(10 * 5 * 16, 64, key=keys[1]), key=keys[1]), + jax.nn.relu, + ] + + self.actor = [ + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[2]), keys[2]), + act_fn, + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[3]), keys[3]), + act_fn, + truncated_normal_init(eqx.nn.Linear(64, num_actions, key=keys[4]), keys[4]), + ] + + self.critic = [ + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[5]), keys[5]), + act_fn, + truncated_normal_init(eqx.nn.Linear(64, 64, key=keys[6]), keys[6]), + act_fn, + truncated_normal_init(eqx.nn.Linear(64, 1, key=keys[7]), keys[7]), + ] + + def __call__(self, x): + x = x.astype(jnp.float32) + # make channels first + x = jnp.moveaxis(x, -1, 0) + for layer in self.features: + x = layer(x) + actor_mean = jnp.copy(x) + for layer in self.actor: + actor_mean = layer(actor_mean) + critic = jnp.copy(x) + for layer in self.critic: + critic = layer(critic) + return actor_mean, jnp.squeeze(critic, axis=-1) -optimizer = optax.chain(optax.clip_by_global_norm( - args.max_grad_norm), optax.adam(args.lr, eps=1e-5)) + +optimizer = optax.chain(optax.clip_by_global_norm(args.max_grad_norm), optax.adam(args.lr, eps=1e-5)) class 
Transition(NamedTuple): @@ -122,38 +162,40 @@ def _update_step(runner_state): # COLLECT TRAJECTORIES step_fn = jax.vmap(auto_reset(env.step, env.init)) + arrs, static = eqx.partition(runner_state[0], eqx.is_array) + runner_state = eqx.tree_at(lambda x: x[0], runner_state, arrs) + def _env_step(runner_state, unused): - params, opt_state, env_state, last_obs, rng = runner_state + arr_params, opt_state, env_state, last_obs, rng = runner_state + params = eqx.combine(arr_params, static) # SELECT ACTION rng, _rng = jax.random.split(rng) - logits, value = forward.apply(params, last_obs) - pi = distrax.Categorical(logits=logits) - action = pi.sample(seed=_rng) - log_prob = pi.log_prob(action) + __rng = jax.random.split(_rng, last_obs.shape[0]) + logits, value = eqx.filter_vmap(params)(last_obs) + # pi = distrax.Categorical(logits=logits) + # action = pi.sample(seed=_rng) + # log_prob = pi.log_prob(action) + pi = eqx.filter_vmap(distributions.Categorical)(logits) + action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) + action = action.astype("int32") + log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, action) # STEP ENV rng, _rng = jax.random.split(rng) keys = jax.random.split(_rng, env_state.observation.shape[0]) env_state = step_fn(env_state, action, keys) transition = Transition( - env_state.terminated, - action, - value, - jnp.squeeze(env_state.rewards), - log_prob, - last_obs + env_state.terminated, action, value, jnp.squeeze(env_state.rewards), log_prob, last_obs ) - runner_state = (params, opt_state, env_state, - env_state.observation, rng) + runner_state = (arr_params, opt_state, env_state, env_state.observation, rng) return runner_state, transition - runner_state, traj_batch = jax.lax.scan( - _env_step, runner_state, None, args.num_steps - ) + runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, args.num_steps) + runner_state = eqx.tree_at(lambda x: x[0], runner_state, eqx.combine(runner_state[0], static)) # CALCULATE ADVANTAGE 
params, opt_state, env_state, last_obs, rng = runner_state - _, last_val = forward.apply(params, last_obs) + _, last_val = eqx.filter_vmap(params)(last_obs) def _calculate_gae(traj_batch, last_val): def _get_advantages(gae_and_next_value, transition): @@ -164,10 +206,7 @@ def _get_advantages(gae_and_next_value, transition): transition.reward, ) delta = reward + args.gamma * next_value * (1 - done) - value - gae = ( - delta - + args.gamma * args.gae_lambda * (1 - done) * gae - ) + gae = delta + args.gamma * args.gae_lambda * (1 - done) * gae return (gae, value), gae _, advantages = jax.lax.scan( @@ -181,6 +220,8 @@ def _get_advantages(gae_and_next_value, transition): advantages, targets = _calculate_gae(traj_batch, last_val) + params_arr, static = eqx.partition(params, eqx.is_array) + # UPDATE NETWORK def _update_epoch(update_state, unused): def _update_minbatch(tup, batch_info): @@ -189,21 +230,19 @@ def _update_minbatch(tup, batch_info): def _loss_fn(params, traj_batch, gae, targets): # RERUN NETWORK - logits, value = forward.apply(params, traj_batch.obs) - pi = distrax.Categorical(logits=logits) - log_prob = pi.log_prob(traj_batch.action) + logits, value = eqx.filter_vmap(params)(traj_batch.obs) + # pi = distrax.Categorical(logits=logits) + # log_prob = pi.log_prob(traj_batch.action) + pi = eqx.filter_vmap(distributions.Categorical)(logits) + log_prob = eqx.filter_vmap(lambda x, y: x.log_prob(y))(pi, traj_batch.action) # CALCULATE VALUE LOSS - value_pred_clipped = traj_batch.value + ( - value - traj_batch.value - ).clip(-args.clip_eps, args.clip_eps) - value_losses = jnp.square(value - targets) - value_losses_clipped = jnp.square( - value_pred_clipped - targets) - value_loss = ( - 0.5 * jnp.maximum(value_losses, - value_losses_clipped).mean() + value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( + -args.clip_eps, args.clip_eps ) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + 
value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() # CALCULATE ACTOR LOSS ratio = jnp.exp(log_prob - traj_batch.log_prob) @@ -221,18 +260,13 @@ def _loss_fn(params, traj_batch, gae, targets): loss_actor = loss_actor.mean() entropy = pi.entropy().mean() - total_loss = ( - loss_actor - + args.vf_coef * value_loss - - args.ent_coef * entropy - ) + total_loss = loss_actor + args.vf_coef * value_loss - args.ent_coef * entropy return total_loss, (value_loss, loss_actor, entropy) - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - total_loss, grads = grad_fn( - params, traj_batch, advantages, targets) + grad_fn = eqx.filter_value_and_grad(_loss_fn, has_aux=True) + total_loss, grads = grad_fn(eqx.combine(params, static), traj_batch, advantages, targets) updates, opt_state = optimizer.update(grads, opt_state) - params = optax.apply_updates(params, updates) + params = eqx.apply_updates(params, updates) return (params, opt_state), total_loss params, opt_state, traj_batch, advantages, targets, rng = update_state @@ -243,38 +277,27 @@ def _loss_fn(params, traj_batch, gae, targets): ), "batch size must be equal to number of steps * number of envs" permutation = jax.random.permutation(_rng, batch_size) batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map( - lambda x: x.reshape((batch_size,) + x.shape[2:]), batch - ) - shuffled_batch = jax.tree_util.tree_map( - lambda x: jnp.take(x, permutation, axis=0), batch - ) + batch = jax.tree_util.tree_map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch) + shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch) minibatches = jax.tree_util.tree_map( - lambda x: jnp.reshape( - x, [num_minibatches, -1] + list(x.shape[1:]) - ), + lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])), shuffled_batch, ) - (params, opt_state), total_loss = jax.lax.scan( - _update_minbatch, (params, opt_state), minibatches - ) - update_state = (params, 
opt_state, traj_batch, - advantages, targets, rng) + (params, opt_state), total_loss = jax.lax.scan(_update_minbatch, (params, opt_state), minibatches) + update_state = (params, opt_state, traj_batch, advantages, targets, rng) return update_state, total_loss - update_state = (params, opt_state, traj_batch, - advantages, targets, rng) - update_state, loss_info = jax.lax.scan( - _update_epoch, update_state, None, args.update_epochs - ) + update_state = (params_arr, opt_state, traj_batch, advantages, targets, rng) + update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, args.update_epochs) params, opt_state, _, _, _, rng = update_state - runner_state = (params, opt_state, env_state, last_obs, rng) + runner_state = (eqx.combine(params, static), opt_state, env_state, last_obs, rng) return runner_state, loss_info + return _update_step -@jax.jit +@eqx.filter_jit def evaluate(params, rng_key): step_fn = jax.vmap(env.step) rng_key, sub_key = jax.random.split(rng_key) @@ -288,15 +311,21 @@ def cond_fn(tup): def loop_fn(tup): state, R, rng_key = tup - logits, value = forward.apply(params, state.observation) + logits, value = eqx.filter_vmap(params)(state.observation) # action = logits.argmax(axis=-1) - pi = distrax.Categorical(logits=logits) + # pi = distrax.Categorical(logits=logits) + # rng_key, _rng = jax.random.split(rng_key) + # action = pi.sample(seed=_rng) + pi = eqx.filter_vmap(distributions.Categorical)(logits) rng_key, _rng = jax.random.split(rng_key) - action = pi.sample(seed=_rng) + __rng = jax.random.split(_rng, state.observation.shape[0]) + action = eqx.filter_vmap(lambda x, y: x.sample(y))(pi, __rng) + action = action.astype("int32") rng_key, _rng = jax.random.split(rng_key) keys = jax.random.split(_rng, state.observation.shape[0]) state = step_fn(state, action, keys) return state, R + state.rewards, rng_key + state, R, _ = jax.lax.while_loop(cond_fn, loop_fn, (state, R, rng_key)) return R.mean() @@ -306,13 +335,12 @@ def train(rng): st = 
time.time() # INIT NETWORK rng, _rng = jax.random.split(rng) - init_x = jnp.zeros((1, ) + env.observation_shape) - params = forward.init(_rng, init_x) - opt_state = optimizer.init(params=params) + model = ActorCritic(env.num_actions, _rng, "tanh") + opt_state = optimizer.init(params=eqx.filter(model, eqx.is_inexact_array)) # INIT UPDATE FUNCTION _update_step = make_update_fn() - jitted_update_step = jax.jit(_update_step) + jitted_update_step = eqx.filter_jit(_update_step) # INIT ENV rng, _rng = jax.random.split(rng) @@ -320,7 +348,7 @@ def train(rng): env_state = jax.jit(jax.vmap(env.init))(reset_rng) rng, _rng = jax.random.split(rng) - runner_state = (params, opt_state, env_state, env_state.observation, _rng) + runner_state = (model, opt_state, env_state, env_state.observation, _rng) # warm up _, _ = jitted_update_step(runner_state) diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 404dbb471..871587745 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -3,8 +3,8 @@ pytest pytest-xdist matplotlib ipython -# hot fix. to avoid errors in Py3.8 dm-haiku==0.0.10 +equinox pytest-cov pgx-minatar black