Skip to content

Commit 3a92c53

Browse files
committed
rename source code
1 parent 563d570 commit 3a92c53

File tree

19 files changed

+50
-21
lines changed

19 files changed

+50
-21
lines changed

ajet/backbone/trainer_verl.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,35 @@ def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataP
144144
logger.info(f'task_id_counter: {task_id_counter}')
145145
return gen_batch_output
146146

147+
def import_or_export_data_proto(batch: DataProto, direction: str = "export", file: str = "./tmp.pkl") -> DataProto:
148+
"""Import or export a DataProto batch to/from a pickle file.
149+
150+
Args:
151+
batch: The DataProto batch object. Used when direction is "export";
152+
ignored (can be None) when direction is "import".
153+
direction: "import" to load a batch from file, "export" to save the batch to file.
154+
file: Path to the pickle file. Defaults to "./tmp.pkl".
155+
156+
Returns:
157+
The DataProto batch — either the one just loaded (import) or the one just saved (export).
158+
159+
Raises:
160+
ValueError: If direction is not "import" or "export".
161+
FileNotFoundError: If direction is "import" and the file does not exist.
162+
"""
163+
import pickle
164+
if direction == "export":
165+
with open(file, "wb") as f:
166+
pickle.dump(batch, f)
167+
logger.info(f"[import_or_export_data_proto] Exported batch to {file}")
168+
return batch
169+
elif direction == "import":
170+
with open(file, "rb") as f:
171+
batch = pickle.load(f)
172+
logger.info(f"[import_or_export_data_proto] Imported batch from {file}")
173+
return batch
174+
else:
175+
raise ValueError(f"direction must be 'import' or 'export', got '{direction}'")
147176

148177
def compute_advantage(
149178
data: DataProto,

ajet/copilot/job.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from typing import Any, Callable, Union, cast
1515
from loguru import logger
16-
from ajet.default_config.ajet_default import Config
16+
from ajet.default_config.ajet_config_schema import Config
1717
from ajet.utils.config_utils import (
1818
expand_ajet_hierarchical_config,
1919
read_ajet_hierarchical_config,
@@ -42,7 +42,7 @@ class AgentJetJob:
4242
"""Programmatic interface for configuring and launching AgentJet training jobs.
4343
4444
Args:
45-
base_yaml_config: Path to base YAML configuration file. If None, uses default config (at ./ajet/default_config/ajet_ts_default.yaml).
45+
base_yaml_config: Path to base YAML configuration file. If None, uses default config (at ./ajet/default_config/ajet_swarm_default.yaml).
4646
experiment_dir: Directory where experiment outputs will be saved.
4747
project_name: Name of the project for organizing experiments.
4848
experiment_name: Unique name for this specific experiment run.
@@ -86,7 +86,7 @@ def __init__(
8686
) -> None:
8787

8888
if base_yaml_config is None:
89-
base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml"))
89+
base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_swarm_default.yaml"))
9090
else:
9191
logger.warning(f"Reading config from {base_yaml_config}.")
9292
time.sleep(1)
@@ -121,7 +121,7 @@ def __init__(
121121
self.max_model_len: int = cast(int, max_model_len)
122122
self.mini_batch_num: int = cast(int, mini_batch_num)
123123

124-
# see `ajet/default_config/ajet_ts_default.yaml`
124+
# see `ajet/default_config/ajet_swarm_default.yaml`
125125
overrides = {
126126
# left: [yaml key navigation] right: [AgentJetJob self attr]
127127
"ajet.experiment_dir": "experiment_dir",

ajet/copilot/train-complex-blackbox/SKILL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ from ajet.copilot.job import AgentJetJob
5454
from ajet.task_reader import RouterTaskReader
5555
from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
5656
from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
57-
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
57+
from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo
5858
from ajet.tuner_lib.experimental.swarm_client import SwarmClient
5959

6060
# python -m tutorial.example_math_swarm.math

ajet/copilot/write-swarm-client/SKILL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ Below are some reference materials.
365365
```python
366366
from ajet.copilot.job import AgentJetJob
367367
from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete
368-
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
368+
from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo
369369
from ajet.task_reader import RouterTaskReader
370370
from tutorial.example_academic_trans_swarm.trans import execute_agent
371371

File renamed without changes.

ajet/launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def main():
207207
if args.swarm_server and (not args.conf):
208208
args.conf = os.path.abspath(
209209
os.path.join(
210-
os.path.dirname(__file__), "default_config/ajet_ts_default.yaml"
210+
os.path.dirname(__file__), "default_config/ajet_swarm_default.yaml"
211211
)
212212
)
213213
assert os.path.exists(args.conf), (

ajet/swarm_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def cmd_start(args):
4646
if not args.conf:
4747
args.conf = os.path.abspath(
4848
os.path.join(
49-
os.path.dirname(__file__), "default_config/ajet_ts_default.yaml"
49+
os.path.dirname(__file__), "default_config/ajet_swarm_default.yaml"
5050
)
5151
)
5252
assert os.path.exists(args.conf), (

docs/en/swarm_best_practice.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ Hint: you do not have to use `run_episodes_until_all_complete`, you are free to
133133
```python
134134
from ajet.copilot.job import AgentJetJob
135135
from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete
136-
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
136+
from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo
137137
from ajet.task_reader import RouterTaskReader
138138
from tutorial.example_academic_trans_swarm.trans import execute_agent
139139

docs/en/tune_your_first_agent.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ Create your client script. The client reads the dataset, runs the agent workflow
496496
from ajet.task_reader import RouterTaskReader
497497
from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
498498
from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
499-
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
499+
from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo
500500
from ajet.tuner_lib.experimental.swarm_client import SwarmClient
501501

502502
# Configuration
@@ -649,7 +649,7 @@ The server handles gradient computation and model updates automatically.
649649
from ajet.task_reader import RouterTaskReader
650650
from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
651651
from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
652-
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
652+
from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo
653653
from ajet.tuner_lib.experimental.swarm_client import SwarmClient
654654

655655
GRPO_N = 4 # grpo group size

0 commit comments

Comments
 (0)