Skip to content

Commit 4beb3ad

Browse files
renamed lib
1 parent e2d2770 commit 4beb3ad

1 file changed

Lines changed: 11 additions & 11 deletions

File tree

open_instruct/data_loader.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from transformers import PreTrainedTokenizer
3535

3636
from open_instruct import data_types, padding_free_collator, utils
37-
from open_instruct import replay_buffer as replay_buffer_mod
37+
from open_instruct import replay_buffer as replay_buffer_lib
3838
from open_instruct.data_types import EnvConfig, EnvConfigEntry
3939
from open_instruct.dataset_transformation import (
4040
ENV_CONFIG_KEY,
@@ -487,14 +487,14 @@ class StreamingDataLoaderConfig:
487487
save_traces: bool = False
488488
rollouts_save_path: str = "/weka/oe-adapt-default/allennlp/deletable_rollouts/"
489489

490-
replay_buffer: replay_buffer_mod.ReplayBufferConfig = field(default_factory=replay_buffer_mod.ReplayBufferConfig)
490+
replay_buffer: replay_buffer_lib.ReplayBufferConfig = field(default_factory=replay_buffer_lib.ReplayBufferConfig)
491491

492492
# Computed at post_init
493493
max_possible_score: float = 1.0
494494

495495
def __post_init__(self):
496496
if isinstance(self.replay_buffer, dict):
497-
self.replay_buffer = replay_buffer_mod.ReplayBufferConfig(**self.replay_buffer)
497+
self.replay_buffer = replay_buffer_lib.ReplayBufferConfig(**self.replay_buffer)
498498
assert self.pack_length >= self.max_prompt_token_length + self.response_length, (
499499
"The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!"
500500
)
@@ -745,7 +745,7 @@ def process_single_result(
745745
replenish_prompts: bool,
746746
param_prompt_Q: ray_queue.Queue | None,
747747
base_env_config: EnvConfig,
748-
) -> replay_buffer_mod.ProcessedResult | None:
748+
) -> replay_buffer_lib.ProcessedResult | None:
749749
assert result.index is not None
750750
assert result.logprobs is not None
751751
assert result.reward_scores is not None
@@ -793,7 +793,7 @@ def process_single_result(
793793
if filter_zero_std_samples and np.std(result.reward_scores) == 0:
794794
return None
795795

796-
return replay_buffer_mod.ProcessedResult(
796+
return replay_buffer_lib.ProcessedResult(
797797
result=result,
798798
queries=repeat_each([query], generation_config.n),
799799
ground_truths=repeat_each([ground_truth], generation_config.n),
@@ -808,7 +808,7 @@ def process_single_result(
808808

809809

810810
def combine_processed_results(
811-
processed_results: list[replay_buffer_mod.ProcessedResult],
811+
processed_results: list[replay_buffer_lib.ProcessedResult],
812812
generation_config: vllm.SamplingParams,
813813
actor_manager=None,
814814
) -> tuple[data_types.GenerationResult, Batch, dict, BatchStatistics] | tuple[None, None, None, None]:
@@ -975,7 +975,7 @@ def accumulate_inference_batches(
975975
"replenish_prompts requires param_prompt_Q and iter_dataloader and dataset"
976976
)
977977

978-
processed_results: list[replay_buffer_mod.ProcessedResult] = []
978+
processed_results: list[replay_buffer_lib.ProcessedResult] = []
979979
total_filtered_prompts = 0
980980
filtered_prompt_zero = 0
981981
filtered_prompt_solved = 0
@@ -1239,12 +1239,12 @@ def __init__(
12391239
self.metadata_saved = False
12401240

12411241
rb = config.replay_buffer
1242-
self._table = replay_buffer_mod.Table(
1242+
self._table = replay_buffer_lib.Table(
12431243
max_size=rb.capacity or global_batch_size,
1244-
sampler=replay_buffer_mod.make_selector(rb.sampler),
1245-
remover=replay_buffer_mod.make_selector(rb.remover),
1244+
sampler=replay_buffer_lib.make_selector(rb.sampler),
1245+
remover=replay_buffer_lib.make_selector(rb.remover),
12461246
max_times_sampled=rb.max_times_sampled,
1247-
rate_limiter=replay_buffer_mod.MinSize(rb.min_size or global_batch_size),
1247+
rate_limiter=replay_buffer_lib.MinSize(rb.min_size or global_batch_size),
12481248
)
12491249

12501250
if initial_state is not None:

0 commit comments

Comments
 (0)