Skip to content

Commit 4efab25

Browse files
Make AgenticGRPOLearner import optional.
PiperOrigin-RevId: 896861388
1 parent 49742a1 commit 4efab25

2 files changed

Lines changed: 10 additions & 1 deletion

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,6 @@
6767
from tunix.rl import rl_cluster as rl_cluster_lib
6868
from tunix.rl.rollout import base_rollout
6969
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
70-
from tunix.rl.agentic.agentic_grpo_learner import GrpoConfig as AgenticGrpoConfig, GrpoLearner as AgenticGrpoLearner
7170
from tunix.sft import metrics_logger, profiler
7271

7372
# for vLLM we can skip JAX precompilation with this flag, it makes startup faster
@@ -601,6 +600,15 @@ def _reward_fn(**kwargs):
601600
max_logging.log("Setting up RL trainer...")
602601
if trainer_config.rl.use_agentic_rollout:
603602
max_logging.log("Using AgenticGRPOLearner with async online rollouts.")
603+
# TODO: Remove this try-except once the dependency on tunix is fixed.
604+
try:
605+
from tunix.rl.agentic.agentic_grpo_learner import GrpoConfig as AgenticGrpoConfig # pylint: disable=import-outside-toplevel
606+
from tunix.rl.agentic.agentic_grpo_learner import GrpoLearner as AgenticGrpoLearner # pylint: disable=import-outside-toplevel
607+
except ImportError as e:
608+
raise ValueError(
609+
"tunix.rl.agentic dependencies are not installed! "
610+
"Please install tunix with agentic support to use 'use_agentic_rollout'."
611+
) from e
604612
grpo_config = AgenticGrpoConfig(
605613
num_generations=trainer_config.rl.num_generations,
606614
num_iterations=trainer_config.rl.num_iterations,

tests/post_training/unit/train_rl_test.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -319,6 +319,7 @@ def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files
319319
# Configs
320320
trainer_config = SimpleNamespace(
321321
debug=SimpleNamespace(rl=False),
322+
rl=SimpleNamespace(use_agentic_rollout=False),
322323
tokenizer_path="dummy_path",
323324
dataset_name="dummy_dataset",
324325
train_split="train",

0 commit comments

Comments
 (0)