@@ -779,7 +779,7 @@ def init_dynamo():
779779
780780 # If inference is not colocated, initialize collective communication for weight updates.
781781 # Dynamo backend does not support weight updates — skip collective init and refit (the "vllm" backend is skipped by the same check).
782- if not colocated_inference and backend != "dynamo" :
782+ if not colocated_inference and backend not in ( "dynamo" , "vllm" ) :
783783 t0 = time .perf_counter ()
784784 ip , port = train_cluster .get_master_address_and_port ()
785785 print (f"Using ip: { ip } , port: { port } for collective communication" , flush = True )
@@ -800,7 +800,7 @@ def init_dynamo():
800800
801801 # prepare refit info
802802 state_dict_info = policy .prepare_refit_info ()
803- if policy_generation is not None and backend != "dynamo" :
803+ if policy_generation is not None and backend not in ( "dynamo" , "vllm" ) :
804804 policy_generation .prepare_refit_info (state_dict_info )
805805
806806 # Calculate total setup time
@@ -1393,7 +1393,7 @@ def grpo_train(
13931393 if policy_generation is None :
13941394 policy_generation = policy # type: ignore
13951395 NEED_REFIT = False
1396- elif master_config ["policy" ]["generation" ]["backend" ] == "dynamo" :
1396+ elif master_config ["policy" ]["generation" ]["backend" ] in ( "dynamo" , "vllm" ) :
13971397 NEED_REFIT = False
13981398 POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
13991399 assert policy_generation is not None # for mypy type check
@@ -1579,7 +1579,9 @@ def grpo_train(
15791579 input_batch = repeated_batch ,
15801580 tokenizer = tokenizer ,
15811581 task_to_env = task_to_env ,
1582- max_seq_len = None ,
1582+ max_seq_len = master_config ["policy" ][
1583+ "max_total_sequence_length"
1584+ ],
15831585 generation_config = generation_config ,
15841586 max_rollout_turns = None ,
15851587 greedy = False ,
@@ -2316,7 +2318,7 @@ def validate(
23162318 input_batch = val_batch ,
23172319 tokenizer = tokenizer ,
23182320 task_to_env = val_task_to_env ,
2319- max_seq_len = None ,
2321+ max_seq_len = master_config [ "policy" ][ "max_total_sequence_length" ] ,
23202322 generation_config = generation_config ,
23212323 max_rollout_turns = None ,
23222324 greedy = False ,
@@ -2489,7 +2491,7 @@ def async_grpo_train(
24892491 if policy_generation is None :
24902492 policy_generation = policy
24912493 NEED_REFIT = False
2492- elif master_config ["policy" ]["generation" ]["backend" ] == "dynamo" :
2494+ elif master_config ["policy" ]["generation" ]["backend" ] in ( "dynamo" , "vllm" ) :
24932495 NEED_REFIT = False
24942496 POLICY_GENERATION_STALE = True
24952497 assert policy_generation is not None
@@ -2950,6 +2952,11 @@ def async_grpo_train(
29502952 weight_version += 1
29512953 trajectory_collector .set_weight_version .remote (weight_version )
29522954 trajectory_collector .resume_after_refit .remote ()
2955+ else :
2956+ # Advance the trajectory collector's weight version even when refit is skipped
2957+ # so that the replay buffer can sample trajectories targeted for subsequent steps.
2958+ weight_version += 1
2959+ trajectory_collector .set_weight_version .remote (weight_version )
29532960
29542961 # Clear logger metrics after each refit (weight sync), starting a new logging cycle
29552962 if policy_generation is not None :
0 commit comments