|
13 | 13 |
|
14 | 14 | from typing import Any, Callable, Union, cast |
15 | 15 | from loguru import logger |
16 | | -from ajet.default_config.ajet_default import Config |
| 16 | +from ajet.default_config.ajet_config_schema import Config |
17 | 17 | from ajet.utils.config_utils import ( |
18 | 18 | expand_ajet_hierarchical_config, |
19 | 19 | read_ajet_hierarchical_config, |
@@ -42,7 +42,7 @@ class AgentJetJob: |
42 | 42 | """Programmatic interface for configuring and launching AgentJet training jobs. |
43 | 43 |
|
44 | 44 | Args: |
45 | | - base_yaml_config: Path to base YAML configuration file. If None, uses default config (at ./ajet/default_config/ajet_ts_default.yaml). |
| 45 | + base_yaml_config: Path to base YAML configuration file. If None, uses default config (at ./ajet/default_config/ajet_swarm_default.yaml). |
46 | 46 | experiment_dir: Directory where experiment outputs will be saved. |
47 | 47 | project_name: Name of the project for organizing experiments. |
48 | 48 | experiment_name: Unique name for this specific experiment run. |
@@ -86,7 +86,7 @@ def __init__( |
86 | 86 | ) -> None: |
87 | 87 |
|
88 | 88 | if base_yaml_config is None: |
89 | | - base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) |
| 89 | + base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_swarm_default.yaml")) |
90 | 90 | else: |
91 | 91 | logger.warning(f"Reading config from {base_yaml_config}.") |
92 | 92 | time.sleep(1) |
@@ -121,7 +121,7 @@ def __init__( |
121 | 121 | self.max_model_len: int = cast(int, max_model_len) |
122 | 122 | self.mini_batch_num: int = cast(int, mini_batch_num) |
123 | 123 |
|
124 | | - # see `ajet/default_config/ajet_ts_default.yaml` |
| 124 | + # see `ajet/default_config/ajet_swarm_default.yaml` |
125 | 125 | overrides = { |
126 | 126 | # left: [yaml key navigation] right: [AgentJetJob self attr] |
127 | 127 | "ajet.experiment_dir": "experiment_dir", |
|
0 commit comments