diff --git a/src/maxtext/experimental/rl/grpo_utils.py b/src/maxtext/experimental/rl/grpo_utils.py index 8989405eab..b1e9bec041 100644 --- a/src/maxtext/experimental/rl/grpo_utils.py +++ b/src/maxtext/experimental/rl/grpo_utils.py @@ -31,7 +31,7 @@ from pathwaysutils.experimental import split_by_mesh_axis -def _identity(x): +def _identity(x: Any) -> Any: return x @@ -40,16 +40,16 @@ def _identity(x): def compute_log_probs( - model, - params, - inputs, - inputs_position, - inputs_segmentation, - completion_segmentation, - config, - is_train=False, - rngs=None, -): + model: Any, + params: Any, + inputs: jax.Array, + inputs_position: jax.Array, + inputs_segmentation: jax.Array, + completion_segmentation: jax.Array, + config: Any, + is_train: bool = False, + rngs: Any | None = None, +) -> tuple[jax.Array, dict[str, Any]]: """Computes per-token log-probabilities for a sequence of tokens. This helper calls model.apply (with dropout enabled if is_train) to obtain @@ -114,14 +114,14 @@ def compute_log_probs( def compute_log_probs_nnx( - model, - inputs, - inputs_position, - inputs_segmentation, - completion_segmentation, - config, - is_train=False, -): + model: nnx.Module, + inputs: jax.Array, + inputs_position: jax.Array, + inputs_segmentation: jax.Array, + completion_segmentation: jax.Array, + config: Any, + is_train: bool = False, +) -> tuple[jax.Array, dict[str, Any]]: """Compute per-token log-probabilities for an NNX policy. Mirrors `compute_log_probs` but takes an `nnx.Module` directly: the model @@ -170,7 +170,12 @@ def compute_log_probs_nnx( return token_log_probs, intermediate_outputs -def generate_offline_completions(config, tokenizer_model, inference_engine, data): +def generate_offline_completions( + config: Any, + tokenizer_model: Any, + inference_engine: Any, + data: dict[str, Any], +) -> dict[str, Any]: """Generates completions for a batch of prompts using an offline engine. Args: @@ -238,8 +243,12 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data def pathways_reshard_nnx( - config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model -): + config: Any, + inference_engine: Any, + policy_state_model: Any, + source_shardings_model: Any, + destination_shardings_model: Any, +) -> None: """Reshard NNX policy params onto the inference mesh. Splits the policy `nnx.Param` state out of the training-side TrainStateNNX @@ -271,7 +280,14 @@ def pathways_reshard_nnx( inference_engine.update_params(resharded_params) -def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings): +def pathways_reshard( + config: Any, + inference_engine: Any, + params: Any, + source_shardings: Any, + source_mesh: jax.sharding.Mesh, + destination_shardings: Any, +) -> None: """Reshards model parameters from training to inference sharding. This function handles the resharding of parameters between different device @@ -314,7 +330,7 @@ def pathways_reshard(config, inference_engine, params, source_shardings, source_ inference_engine.update_params(resharded_params) -def dummy_reward_len(valid_seq_mask): +def dummy_reward_len(valid_seq_mask: jax.Array) -> jax.Array: """Calculates a dummy reward based on the length of the valid sequence. Args: @@ -330,7 +346,12 @@ def dummy_reward_len(valid_seq_mask): return reward -def concatenate_prompt_with_completions(config, tokenizer_model, data, completions): +def concatenate_prompt_with_completions( + config: Any, + tokenizer_model: Any, + data: dict[str, Any], + completions: jax.Array, +) -> dict[str, Any]: """Concatenates prompts with their generated completions. This function takes a batch of prompts and a corresponding batch of @@ -392,7 +413,11 @@ def _concat_and_find_eos(prompt, true_len, completion): return data -def pad_or_trim(arr, max_target_length, pad_token): +def pad_or_trim( + arr: list[np.ndarray], + max_target_length: int, + pad_token: int | float, +) -> np.ndarray: """Pads or trims a list of sequences to a maximum target length. Args: @@ -412,7 +437,12 @@ def pad_or_trim(arr, max_target_length, pad_token): return padded -def filter_and_split(config, example_batch, num_groups, global_batch_size_per_group): +def filter_and_split( + config: Any, + example_batch: dict[str, Any], + num_groups: int, + global_batch_size_per_group: int, +) -> list[dict[str, Any]]: """Splits an example_batch into a list of smaller batches. Samples are taken from the beginning of the input batch, and extras are @@ -466,7 +496,10 @@ def filter_and_split(config, example_batch, num_groups, global_batch_size_per_gr return list_of_output_batches -def _maybe_find_intermediate_sharding(source_sharding, target_sharding): +def _maybe_find_intermediate_sharding( + source_sharding: jax.sharding.Sharding, + target_sharding: jax.sharding.Sharding, +) -> jax.sharding.NamedSharding | None: """Maybe finds an intermediate sharding to reshard to before target sharding.""" if not isinstance(source_sharding, jax.sharding.NamedSharding) or not isinstance( target_sharding, jax.sharding.NamedSharding @@ -478,7 +511,10 @@ def _maybe_find_intermediate_sharding(source_sharding, target_sharding): src_mesh = source_sharding.mesh dst_mesh = target_sharding.mesh - def _get_sharding_dims(sharding, mesh): + def _get_sharding_dims( + sharding: jax.sharding.NamedSharding, + mesh: jax.sharding.Mesh, + ) -> tuple[dict[Any, int], int]: sharding_dims = {} used_mesh_axis_names = set() for i, axis_name in enumerate(sharding.spec): @@ -565,7 +601,11 @@ def _get_sharding_dims(sharding, mesh): return intermediate_sharding -def _experimental_pre_reshard(splitfn, src_pytree, target_shardings): +def _experimental_pre_reshard( + splitfn: Callable[..., Any], + src_pytree: Any, + target_shardings: Any, +) -> Any: """Simple heuristic to determine if resharding with replicated all-gather is needed.""" src_shardings = jax.tree_util.tree_map( lambda x: x.sharding, @@ -625,13 +665,13 @@ def _get_reshard_fn_pathwaysutils( cache_resharding_plans: bool, donate: bool, use_experimental_pre_reshard: bool, -): +) -> Callable[..., Any]: """Returns a reshard function using pathwaysutils.""" def reshard_fn( x: Any, sharding: jax.sharding.Sharding | Any, - ): + ) -> Any: if use_experimental_pre_reshard: x = _experimental_pre_reshard(split_by_mesh_axis.split_by_mesh_axis, x, sharding) @@ -650,7 +690,7 @@ def _get_reshard_fn( donate: bool, use_experimental_pre_reshard: bool, get_reshard_fns: list[Callable[..., Any]], -): +) -> Callable[..., Any]: """Returns a reshard function.""" for get_reshard_fn in get_reshard_fns: try: