Skip to content

chore: add type annotations to grpo_utils based on docstrings#4076

Open
harshaa765 wants to merge 1 commit into
AI-Hypercomputer:mainfrom
harshaa765:feat/type-annotate-grpo-utils
Open

chore: add type annotations to grpo_utils based on docstrings#4076
harshaa765 wants to merge 1 commit into
AI-Hypercomputer:mainfrom
harshaa765:feat/type-annotate-grpo-utils

Conversation

@harshaa765

Copy link
Copy Markdown

Summary

Adds Python type annotations to all previously unannotated function signatures in src/maxtext/experimental/rl/grpo_utils.py.

  • Core log-prob helpers (compute_log_probs, compute_log_probs_nnx): inputs typed as jax.Array, returns as tuple[jax.Array, dict[str, Any]]
  • NNX policy helper (compute_log_probs_nnx): model typed as nnx.Module
  • Resharding helpers (pathways_reshard, pathways_reshard_nnx, _maybe_find_intermediate_sharding, reshard_pytree-path): shardings typed with jax.sharding.{Mesh,Sharding,NamedSharding}, functions returning None or Callable[..., Any] made explicit
  • Data pipeline helpers (pad_or_trim, filter_and_split, generate_offline_completions): dicts typed as dict[str, Any], sequences as list[np.ndarray], scalars as int | float
  • Nested helpers (_get_sharding_dims, reshard_fn): also annotated inline

No new dependencies introduced — annotations use only jax.Array, nnx.Module, jax.sharding.*, np.ndarray, and standard typing primitives (Any, Callable). The existing jaxtyping.PyTree annotations on reshard_pytree are left unchanged.

Test plan

  • ast.parse syntax check passes locally.
  • No runtime logic changed — annotations only.
  • Existing unit tests in tests/unit/grpo_nnx_test.py continue to cover the annotated functions.

Annotates all previously unannotated function signatures in
experimental/rl/grpo_utils.py using jax.Array, nnx.Module,
jax.sharding.{Mesh,Sharding,NamedSharding}, np.ndarray, and standard
Python typing (Any, Callable, tuple, list, dict) — no new dependencies.

- Core log-prob helpers: inputs typed as jax.Array, returns tuple[jax.Array, dict[str, Any]]
- Resharding helpers: shardings typed with jax.sharding.*, returns None or Callable[..., Any]
- Data pipeline helpers: dicts typed as dict[str, Any], sequences as list[np.ndarray]
- Nested helpers (_get_sharding_dims, reshard_fn) also annotated
- Existing jaxtyping.PyTree annotations on reshard_pytree left unchanged
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant