3434from transformers import PreTrainedTokenizer
3535
3636from open_instruct import data_types , padding_free_collator , utils
37- from open_instruct import replay_buffer as replay_buffer_mod
37+ from open_instruct import replay_buffer as replay_buffer_lib
3838from open_instruct .data_types import EnvConfig , EnvConfigEntry
3939from open_instruct .dataset_transformation import (
4040 ENV_CONFIG_KEY ,
@@ -487,14 +487,14 @@ class StreamingDataLoaderConfig:
487487 save_traces : bool = False
488488 rollouts_save_path : str = "/weka/oe-adapt-default/allennlp/deletable_rollouts/"
489489
490- replay_buffer : replay_buffer_mod .ReplayBufferConfig = field (default_factory = replay_buffer_mod .ReplayBufferConfig )
490+ replay_buffer : replay_buffer_lib .ReplayBufferConfig = field (default_factory = replay_buffer_lib .ReplayBufferConfig )
491491
492492 # Computed at post_init
493493 max_possible_score : float = 1.0
494494
495495 def __post_init__ (self ):
496496 if isinstance (self .replay_buffer , dict ):
497- self .replay_buffer = replay_buffer_mod .ReplayBufferConfig (** self .replay_buffer )
497+ self .replay_buffer = replay_buffer_lib .ReplayBufferConfig (** self .replay_buffer )
498498 assert self .pack_length >= self .max_prompt_token_length + self .response_length , (
499499 "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!"
500500 )
@@ -745,7 +745,7 @@ def process_single_result(
745745 replenish_prompts : bool ,
746746 param_prompt_Q : ray_queue .Queue | None ,
747747 base_env_config : EnvConfig ,
748- ) -> replay_buffer_mod .ProcessedResult | None :
748+ ) -> replay_buffer_lib .ProcessedResult | None :
749749 assert result .index is not None
750750 assert result .logprobs is not None
751751 assert result .reward_scores is not None
@@ -793,7 +793,7 @@ def process_single_result(
793793 if filter_zero_std_samples and np .std (result .reward_scores ) == 0 :
794794 return None
795795
796- return replay_buffer_mod .ProcessedResult (
796+ return replay_buffer_lib .ProcessedResult (
797797 result = result ,
798798 queries = repeat_each ([query ], generation_config .n ),
799799 ground_truths = repeat_each ([ground_truth ], generation_config .n ),
@@ -808,7 +808,7 @@ def process_single_result(
808808
809809
810810def combine_processed_results (
811- processed_results : list [replay_buffer_mod .ProcessedResult ],
811+ processed_results : list [replay_buffer_lib .ProcessedResult ],
812812 generation_config : vllm .SamplingParams ,
813813 actor_manager = None ,
814814) -> tuple [data_types .GenerationResult , Batch , dict , BatchStatistics ] | tuple [None , None , None , None ]:
@@ -975,7 +975,7 @@ def accumulate_inference_batches(
975975 "replenish_prompts requires param_prompt_Q and iter_dataloader and dataset"
976976 )
977977
978- processed_results : list [replay_buffer_mod .ProcessedResult ] = []
978+ processed_results : list [replay_buffer_lib .ProcessedResult ] = []
979979 total_filtered_prompts = 0
980980 filtered_prompt_zero = 0
981981 filtered_prompt_solved = 0
@@ -1239,12 +1239,12 @@ def __init__(
12391239 self .metadata_saved = False
12401240
12411241 rb = config .replay_buffer
1242- self ._table = replay_buffer_mod .Table (
1242+ self ._table = replay_buffer_lib .Table (
12431243 max_size = rb .capacity or global_batch_size ,
1244- sampler = replay_buffer_mod .make_selector (rb .sampler ),
1245- remover = replay_buffer_mod .make_selector (rb .remover ),
1244+ sampler = replay_buffer_lib .make_selector (rb .sampler ),
1245+ remover = replay_buffer_lib .make_selector (rb .remover ),
12461246 max_times_sampled = rb .max_times_sampled ,
1247- rate_limiter = replay_buffer_mod .MinSize (rb .min_size or global_batch_size ),
1247+ rate_limiter = replay_buffer_lib .MinSize (rb .min_size or global_batch_size ),
12481248 )
12491249
12501250 if initial_state is not None :
0 commit comments