Skip to content

Commit 094673e

Browse files
committed
feat: enhance HuggingfaceDatRepo and HuggingFaceTaskReader for improved dataset handling and proxy configuration
fix: update PeriodicDrainThreadPoolExecutor to manage task results and auto-retry functionality
chore: modify example_math_swarm to use updated dataset path and configuration
1 parent 6fe101b commit 094673e

File tree

5 files changed

+67
-29
lines changed

5 files changed

+67
-29
lines changed

ajet/default_config/ajet_default.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ class AjetRollout:
3636
@dataclass
3737
class HuggingfaceDatRepo:
3838
dataset_path: str = "gsm8k"
39+
dataset_name: str | None = None
3940
training_split: str = "train"
4041
validation_split: str = "validation"
42+
http_proxy_address: str = ""
4143

4244

4345
@dataclass

ajet/task_reader/hf_dataset_reader.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11

22
import datasets
3+
import httpx
4+
import huggingface_hub
35

46
from ajet.schema.task import Task
57
from ajet.task_reader.task_reader_base import BaseTaskReader
@@ -17,7 +19,30 @@ def __init__(self, reader_config):
1719
super().__init__(reader_config)
1820
self.reader_config = reader_config
1921
self.as_generator = False
20-
self.dataset_name = self.reader_config.huggingface_dat_repo.dataset_path
22+
self.dataset_path = self.reader_config.huggingface_dat_repo.dataset_path
23+
24+
25+
try:
26+
self.dataset_name = self.reader_config.huggingface_dat_repo.dataset_name
27+
except Exception:
28+
self.dataset_name = None
29+
30+
try:
31+
self.http_proxy_address = getattr(
32+
self.reader_config.huggingface_dat_repo, "http_proxy_address", ""
33+
) or getattr(self.reader_config.huggingface_dat_repo, "http_proxy", "")
34+
except Exception:
35+
self.http_proxy_address = ""
36+
37+
# Configure httpx proxy via set_client_factory (replaces deprecated proxies= arg)
38+
if self.http_proxy_address:
39+
proxy_url = self.http_proxy_address
40+
huggingface_hub.set_client_factory(
41+
lambda **kwargs: httpx.Client(
42+
proxy=proxy_url,
43+
**{k: v for k, v in kwargs.items() if k != "proxies"},
44+
)
45+
)
2146

2247
def _load_dataset_split(self, split: str):
2348
"""
@@ -30,22 +55,26 @@ def _load_dataset_split(self, split: str):
3055
Generator: List of Task objects created from the dataset.
3156
"""
3257
try:
33-
if self.dataset_name.endswith(".parquet"):
58+
59+
if self.dataset_path.endswith(".parquet"):
3460
# Load from local parquet file
35-
dataset = datasets.load_dataset("parquet", data_files=self.dataset_name, split=split)
61+
dataset = datasets.load_dataset(
62+
"parquet", data_files=self.dataset_path, split=split
63+
)
3664
else:
37-
# Load from Hugging Face hub
38-
dataset = datasets.load_dataset(self.dataset_name, split=split)
65+
dataset = datasets.load_dataset(
66+
self.dataset_path, split=split, name=self.dataset_name
67+
)
3968
# shuffle dataset
4069
dataset = dataset.map(lambda example, idx: {"original_idx": idx}, with_indices=True)
4170
dataset = dataset.shuffle()
4271
except Exception as e:
4372
raise ValueError(
44-
f"Failed to load dataset '{self.dataset_name}' with split '{split}': {str(e)}"
73+
f"Failed to load dataset '{self.dataset_path}' with split '{split}': {str(e)}"
4574
)
4675

4776
if len(dataset) == 0:
48-
raise ValueError(f"No examples found in dataset '{self.dataset_name}' with split '{split}'")
77+
raise ValueError(f"No examples found in dataset '{self.dataset_path}' with split '{split}'")
4978

5079
self.as_generator = True
5180

ajet/tuner_lib/experimental/as_swarm_client.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
import httpx
55
import json
6+
import re
67
import yaml
78
from beast_logger import print_dict
89
from typing import List, Tuple
@@ -60,7 +61,7 @@ def __init__(self, server_url: str):
6061

6162
# better logging management
6263
self._last_second_print_buffer: dict[str, float] = {}
63-
self.begin_episode_lock = threading.Lock()
64+
self._begin_episode_lock = threading.Lock()
6465
# record last registered AgentJetJob
6566
self._agent_jet_job = None
6667
# throttle
@@ -202,7 +203,9 @@ def begin_episode(self, discard_episode_timeout=600, episode_type="train", throt
202203
Return:
203204
(episode_uuid, openai_base_url, openai_api_key)
204205
"""
206+
return self._begin_episode_auto_repeat(discard_episode_timeout, episode_type, throttle_policy)
205207

208+
def _begin_episode_auto_repeat(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
206209
# max_episode_time: when an episode has **lasted** for more than X seconds, it will be terminated **locally** by client (call `end_episode` will be re-route to `abort_episode`)
207210
max_episode_time = 2*discard_episode_timeout
208211

@@ -225,7 +228,7 @@ def begin_episode(self, discard_episode_timeout=600, episode_type="train", throt
225228

226229
# when throttle_policy is set, acquire lock to prevent multiple threads from claiming episode at the same time and causing throttle policy to fail
227230
if throttle_policy is not None:
228-
self.begin_episode_lock.acquire()
231+
self._begin_episode_lock.acquire()
229232

230233
try:
231234
# Check throttle policy before claiming episode (only for train episodes)
@@ -259,6 +262,10 @@ def begin_episode(self, discard_episode_timeout=600, episode_type="train", throt
259262
episode_uuid = data.episode_uuid
260263
openai_base_url = data.openai_base_url
261264
openai_api_key = data.openai_api_key
265+
266+
# force replace openai_base_url host with self.server_url
267+
openai_base_url = re.sub(r'^https?://[^/]+', self.server_url, openai_base_url)
268+
262269
self.logger_info(f"Claimed episode {episode_uuid}, current global step: {status_json.get('global_step', 'unknown')}")
263270
return episode_uuid, OpenaiBaseUrlAndApiKey(
264271
base_url=openai_base_url,
@@ -290,8 +297,8 @@ def begin_episode(self, discard_episode_timeout=600, episode_type="train", throt
290297

291298
finally:
292299
if throttle_policy is not None:
293-
if self.begin_episode_lock.locked():
294-
self.begin_episode_lock.release()
300+
if self._begin_episode_lock.locked():
301+
self._begin_episode_lock.release()
295302

296303
def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOutput):
297304

ajet/utils/thread_executors.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,15 @@ def __init__(self, workers=100, auto_retry=True):
5050
self._executor = ThreadPoolExecutor(max_workers=workers)
5151
self._submitted_count = 0
5252
self._auto_retry = auto_retry
53+
self.current_futures = []
5354

5455
def submit(self, fn, *args, **kwargs):
5556
"""Submit a task, blocking if the pending queue is full."""
5657

57-
def retry_wrapper(func, arg):
58+
def retry_wrapper(fn, *args, **kwargs):
5859
while True:
5960
try:
60-
return func(arg)
61+
return fn(*args, **kwargs)
6162
except Exception as e:
6263
logger.exception(f"[run_episodes_until_all_complete] Error executing episode: {e}. Retrying...")
6364

@@ -69,12 +70,19 @@ def retry_wrapper(func, arg):
6970
def submit_with_periodic_drain(self, fn, *args, **kwargs):
7071
"""Submit a task, draining all in-flight work every `drain_every_n_job` submissions."""
7172
drain_every_n_job = self._max_workers
73+
results = []
7274
if self._submitted_count > 0 and self._submitted_count % drain_every_n_job == 0:
73-
self._executor.shutdown(wait=True)
74-
self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
75+
for future in self.current_futures:
76+
try:
77+
results += [future.result()] # Wait for the task to complete and raise exceptions if any
78+
except Exception as e:
79+
logger.exception(f"Error in task execution: {e}")
80+
self.current_futures = []
7581

7682
self._submitted_count += 1
77-
return self.submit(fn, *args, **kwargs)
83+
future = self.submit(fn, *args, **kwargs)
84+
self.current_futures.append(future)
85+
return future, results
7886

7987
def shutdown(self, wait=True):
8088
"""Shut down the underlying executor."""

tutorial/example_math_swarm/math.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@
1818

1919
GRPO_N = 4 # grpo group size
2020
NUM_EPOCH = 10000
21-
DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main"
2221
AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")
2322

2423
REMOTE_BATCH_SIZE = 32
25-
REMOTE_ALLOCATE_GPU_PER_NODE = 4
24+
REMOTE_ALLOCATE_GPU_PER_NODE = 8
2625
REMOTE_TRAIN_MODEL = '/root/agentjet/modelscope_cache/Qwen/Qwen2.5-7B-Instruct'
2726

2827
def main():
@@ -32,7 +31,9 @@ def main():
3231
reader_type = "huggingface_dat_repo",
3332
reader_config = AjetTaskReader(
3433
huggingface_dat_repo = HuggingfaceDatRepo(
35-
dataset_path = DATASET_PATH
34+
dataset_path = "C:/Users/fuqingxu-hub/Downloads/dataset/gsm8k/socratic",
35+
# dataset_path = "openai/gsm8k",
36+
# dataset_name = "main",
3637
)
3738
)
3839
)
@@ -53,20 +54,11 @@ def main():
5354
def rollout(task):
5455
try:
5556
# begin episode
56-
episode_uuid, api_baseurl_key = swarm_worker.begin_episode(
57-
throttle_policy=SwarmThrottlePolicy(
58-
ratio=0.5,
59-
expected_batch_size=REMOTE_BATCH_SIZE,
60-
expected_num_repeat=GRPO_N,
61-
current_task_id=task.task_id
62-
)
63-
)
57+
episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60)
6458
# execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
6559
workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output`
6660
# report output back to swarm remote
6761
swarm_worker.end_episode(task, episode_uuid, workflow_output)
68-
# print global rollout status across the swarm
69-
swarm_worker.print_rollout_stat()
7062
return
7163
except:
7264
pass

0 commit comments

Comments (0)