Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 73 additions & 33 deletions src/maxtext/experimental/rl/grpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pathwaysutils.experimental import split_by_mesh_axis


def _identity(x):
def _identity(x: Any) -> Any:
return x


Expand All @@ -40,16 +40,16 @@ def _identity(x):


def compute_log_probs(
model,
params,
inputs,
inputs_position,
inputs_segmentation,
completion_segmentation,
config,
is_train=False,
rngs=None,
):
model: Any,
params: Any,
inputs: jax.Array,
inputs_position: jax.Array,
inputs_segmentation: jax.Array,
completion_segmentation: jax.Array,
config: Any,
is_train: bool = False,
rngs: Any | None = None,
) -> tuple[jax.Array, dict[str, Any]]:
"""Computes per-token log-probabilities for a sequence of tokens.

This helper calls model.apply (with dropout enabled if is_train) to obtain
Expand Down Expand Up @@ -114,14 +114,14 @@ def compute_log_probs(


def compute_log_probs_nnx(
model,
inputs,
inputs_position,
inputs_segmentation,
completion_segmentation,
config,
is_train=False,
):
model: nnx.Module,
inputs: jax.Array,
inputs_position: jax.Array,
inputs_segmentation: jax.Array,
completion_segmentation: jax.Array,
config: Any,
is_train: bool = False,
) -> tuple[jax.Array, dict[str, Any]]:
"""Compute per-token log-probabilities for an NNX policy.

Mirrors `compute_log_probs` but takes an `nnx.Module` directly: the model
Expand Down Expand Up @@ -170,7 +170,12 @@ def compute_log_probs_nnx(
return token_log_probs, intermediate_outputs


def generate_offline_completions(config, tokenizer_model, inference_engine, data):
def generate_offline_completions(
config: Any,
tokenizer_model: Any,
inference_engine: Any,
data: dict[str, Any],
) -> dict[str, Any]:
"""Generates completions for a batch of prompts using an offline engine.

Args:
Expand Down Expand Up @@ -238,8 +243,12 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data


def pathways_reshard_nnx(
config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model
):
config: Any,
inference_engine: Any,
policy_state_model: Any,
source_shardings_model: Any,
destination_shardings_model: Any,
) -> None:
"""Reshard NNX policy params onto the inference mesh.

Splits the policy `nnx.Param` state out of the training-side TrainStateNNX
Expand Down Expand Up @@ -271,7 +280,14 @@ def pathways_reshard_nnx(
inference_engine.update_params(resharded_params)


def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings):
def pathways_reshard(
config: Any,
inference_engine: Any,
params: Any,
source_shardings: Any,
source_mesh: jax.sharding.Mesh,
destination_shardings: Any,
) -> None:
"""Reshards model parameters from training to inference sharding.

This function handles the resharding of parameters between different device
Expand Down Expand Up @@ -314,7 +330,7 @@ def pathways_reshard(config, inference_engine, params, source_shardings, source_
inference_engine.update_params(resharded_params)


def dummy_reward_len(valid_seq_mask):
def dummy_reward_len(valid_seq_mask: jax.Array) -> jax.Array:
"""Calculates a dummy reward based on the length of the valid sequence.

Args:
Expand All @@ -330,7 +346,12 @@ def dummy_reward_len(valid_seq_mask):
return reward


def concatenate_prompt_with_completions(config, tokenizer_model, data, completions):
def concatenate_prompt_with_completions(
config: Any,
tokenizer_model: Any,
data: dict[str, Any],
completions: jax.Array,
) -> dict[str, Any]:
"""Concatenates prompts with their generated completions.

This function takes a batch of prompts and a corresponding batch of
Expand Down Expand Up @@ -392,7 +413,11 @@ def _concat_and_find_eos(prompt, true_len, completion):
return data


def pad_or_trim(arr, max_target_length, pad_token):
def pad_or_trim(
arr: list[np.ndarray],
max_target_length: int,
pad_token: int | float,
) -> np.ndarray:
"""Pads or trims a list of sequences to a maximum target length.

Args:
Expand All @@ -412,7 +437,12 @@ def pad_or_trim(arr, max_target_length, pad_token):
return padded


def filter_and_split(config, example_batch, num_groups, global_batch_size_per_group):
def filter_and_split(
config: Any,
example_batch: dict[str, Any],
num_groups: int,
global_batch_size_per_group: int,
) -> list[dict[str, Any]]:
"""Splits an example_batch into a list of smaller batches.

Samples are taken from the beginning of the input batch, and extras are
Expand Down Expand Up @@ -466,7 +496,10 @@ def filter_and_split(config, example_batch, num_groups, global_batch_size_per_gr
return list_of_output_batches


def _maybe_find_intermediate_sharding(source_sharding, target_sharding):
def _maybe_find_intermediate_sharding(
source_sharding: jax.sharding.Sharding,
target_sharding: jax.sharding.Sharding,
) -> jax.sharding.NamedSharding | None:
"""Maybe finds an intermediate sharding to reshard to before target sharding."""
if not isinstance(source_sharding, jax.sharding.NamedSharding) or not isinstance(
target_sharding, jax.sharding.NamedSharding
Expand All @@ -478,7 +511,10 @@ def _maybe_find_intermediate_sharding(source_sharding, target_sharding):
src_mesh = source_sharding.mesh
dst_mesh = target_sharding.mesh

def _get_sharding_dims(sharding, mesh):
def _get_sharding_dims(
sharding: jax.sharding.NamedSharding,
mesh: jax.sharding.Mesh,
) -> tuple[dict[Any, int], int]:
sharding_dims = {}
used_mesh_axis_names = set()
for i, axis_name in enumerate(sharding.spec):
Expand Down Expand Up @@ -565,7 +601,11 @@ def _get_sharding_dims(sharding, mesh):
return intermediate_sharding


def _experimental_pre_reshard(splitfn, src_pytree, target_shardings):
def _experimental_pre_reshard(
splitfn: Callable[..., Any],
src_pytree: Any,
target_shardings: Any,
) -> Any:
"""Simple heuristic to determine if resharding with replicated all-gather is needed."""
src_shardings = jax.tree_util.tree_map(
lambda x: x.sharding,
Expand Down Expand Up @@ -625,13 +665,13 @@ def _get_reshard_fn_pathwaysutils(
cache_resharding_plans: bool,
donate: bool,
use_experimental_pre_reshard: bool,
):
) -> Callable[..., Any]:
"""Returns a reshard function using pathwaysutils."""

def reshard_fn(
x: Any,
sharding: jax.sharding.Sharding | Any,
):
) -> Any:
if use_experimental_pre_reshard:
x = _experimental_pre_reshard(split_by_mesh_axis.split_by_mesh_axis, x, sharding)

Expand All @@ -650,7 +690,7 @@ def _get_reshard_fn(
donate: bool,
use_experimental_pre_reshard: bool,
get_reshard_fns: list[Callable[..., Any]],
):
) -> Callable[..., Any]:
"""Returns a reshard function."""
for get_reshard_fn in get_reshard_fns:
try:
Expand Down