Skip to content

Commit e0c0e7e

Browse files
committed
refactor(experimental): consolidate DTA Archon integration
Address PR #1287 review by keeping DTA-specific behavior in experimental modules and narrow Archon integration points. Key changes: - Add DTA allocation, trie, runner, wrapper, Zero1, and rollout helpers - Route controller and rollout allocation through a trajectory adapter - Add Qwen2/Qwen3 cache support and reject unsupported Qwen3.5 DTA - Update DTA docs, examples, and focused integration tests Refs: #1287
1 parent 1cb9db2 commit e0c0e7e

48 files changed

Lines changed: 4420 additions & 235 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

areal/api/cli_args.py

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,10 +1170,42 @@ class TrainEngineConfig:
11701170
metadata={"help": "peft method type. Only LoRA is supported for now."},
11711171
)
11721172

1173-
# Tree training
1174-
enable_tree_training: bool = field(
1175-
default=False,
1176-
metadata={"help": "Enable tree training with flex attention module."},
1173+
# Tree training (str, not Literal: OmegaConf.structured rejects Literal here)
1174+
tree_training_mode: str = field(
1175+
default="disabled",
1176+
metadata={
1177+
"help": (
1178+
"Tree training mode. "
1179+
"'sparse' enables tree training with Flex Attention module (flex attention), "
1180+
"'dta' enables Dynamic Tree Attention (dynamic tree training), "
1181+
"'disabled' disables tree training."
1182+
),
1183+
"choices": ["disabled", "sparse", "dta"],
1184+
},
1185+
)
1186+
dta_block_size: int = field(
1187+
default=2048,
1188+
metadata={
1189+
"help": (
1190+
"Block size for Dynamic Tree Attention. "
1191+
"Set to -1 to disable block-size limit. "
1192+
"Only effective when tree_training_mode='dta'."
1193+
)
1194+
},
1195+
)
1196+
packing_algorithm: str = field(
1197+
default="ffd",
1198+
metadata={
1199+
"help": (
1200+
"Trajectory packing across data-parallel ranks during distributed rollout "
1201+
"(``redistribute_trajectories``). "
1202+
"'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order "
1203+
"n_tree_tokens. "
1204+
"Not to be confused with ``mb_spec.packing_algorithm``, which only "
1205+
"controls micro-batch formation (ffd/kk) during training."
1206+
),
1207+
"choices": ["ffd", "kk", "dta"],
1208+
},
11771209
)
11781210

11791211
# Scheduling
@@ -1246,6 +1278,23 @@ def __post_init__(self):
12461278
"memory_efficient_load is for loading pretrained weights on CPU, "
12471279
"but init_from_scratch creates a model without loading any weights."
12481280
)
1281+
valid_tree_modes = {"disabled", "sparse", "dta"}
1282+
if self.tree_training_mode not in valid_tree_modes:
1283+
raise ValueError(
1284+
f"tree_training_mode must be one of {valid_tree_modes}, got '{self.tree_training_mode}'"
1285+
)
1286+
valid_rollout_packing = {"ffd", "kk", "dta"}
1287+
if self.packing_algorithm not in valid_rollout_packing:
1288+
raise ValueError(
1289+
f"packing_algorithm (rollout) must be one of {valid_rollout_packing}, "
1290+
f"got '{self.packing_algorithm}'"
1291+
)
1292+
if self.tree_training_mode == "dta":
1293+
if self.dta_block_size == 0 or self.dta_block_size < -1:
1294+
raise ValueError(
1295+
f"dta_block_size must be -1 or a positive integer when tree_training_mode='dta', got {self.dta_block_size}."
1296+
)
1297+
12491298
if self._version not in ("v1", "v2"):
12501299
raise ValueError(
12511300
f"_version must be either 'v1' or 'v2', got '{self._version}'"
@@ -1635,6 +1684,22 @@ def __post_init__(self):
16351684
"Please set `actor.use_decoupled_loss=false` in your configuration."
16361685
)
16371686

1687+
if self.packing_algorithm == "dta":
1688+
for norm_name in ["adv_norm", "reward_norm"]:
1689+
norm_config = getattr(self, norm_name)
1690+
if norm_config is not None:
1691+
if (
1692+
norm_config.mean_level == "group"
1693+
or norm_config.std_level == "group"
1694+
):
1695+
raise ValueError(
1696+
f"{norm_name} uses 'group' level normalization, which is incompatible "
1697+
"with packing_algorithm='dta'. DTA requires sequence-level independence, "
1698+
"but 'group' normalization relies on contiguous group slices. Please use "
1699+
"'batch' level normalization or set packing_algorithm='ffd'. "
1700+
"(Group-level support for DTA will be provided in a future release.)"
1701+
)
1702+
16381703
super().__post_init__()
16391704

16401705

areal/engine/fsdp_engine.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,14 @@ def __init__(self, config: TrainEngineConfig):
262262
self.dp_rank: int
263263

264264
self.is_offload: bool = False
265+
self.tree_training_mode: str = self.config.tree_training_mode
266+
if self.tree_training_mode == "dta":
267+
raise ValueError(
268+
"tree_training_mode='dta' is only supported by ArchonEngine. "
269+
"Please use Archon backend or set tree_training_mode to 'disabled'/'sparse'."
270+
)
265271
self._offload_depth: int = 0
266272
self._per_layer_optim_wrapper: PerLayerOptimWrapper | None = None
267-
self.enable_tree_training: bool = self.config.enable_tree_training
268273

269274
@classmethod
270275
def from_pretrained(
@@ -384,7 +389,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
384389
# Create device model
385390
self._create_device_model()
386391

387-
if self.enable_tree_training and self.parallel_helper.sp_size > 1:
392+
if self.tree_training_mode == "sparse" and self.parallel_helper.sp_size > 1:
388393
raise ValueError(
389394
"Tree training currently cannot be enabled with sp_size > 1."
390395
)
@@ -395,7 +400,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
395400
shard_vision_across_sp=self.config.fsdp.shard_vision_across_sp,
396401
)
397402
# Monkey patch: replace attention's forward() with tree attention.
398-
patch_fsdp_for_tree_training(enable=self.enable_tree_training)
403+
patch_fsdp_for_tree_training(enable=self.tree_training_mode == "sparse")
399404

400405
if self.config.use_lora:
401406
self._apply_peft_wrapper()
@@ -733,7 +738,7 @@ def forward_backward_batch(
733738
# module_fsdp.py reads these keys from the **kwargs that transformers
734739
# forwards through.
735740
tree_attn_keys: list[str] = []
736-
if self.enable_tree_training and ctx.trie_node is not None:
741+
if self.tree_training_mode == "sparse" and ctx.trie_node is not None:
737742
padded_size = mb_item.padded_to_length
738743
assert padded_size is not None
739744
tree_kwargs = build_tree_attn_kwargs(
@@ -881,8 +886,8 @@ def process_output(logits: torch.Tensor, ctx_dict: dict[str, Any]) -> None:
881886
self.forward_backward_batch(mb_list, process_output, forward_only=True)
882887

883888
# Step 4: Aggregate and reorder outputs
884-
if self.enable_tree_training:
885-
result = merge_packed_tree_results(outputs, batch_size)
889+
if self.tree_training_mode == "sparse":
890+
return merge_packed_tree_results(outputs, batch_size)
886891
else:
887892
result = reorder_and_pad_outputs(
888893
outputs, output_seqlens, mb_list, aggregate_fn
@@ -1794,7 +1799,7 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
17941799
input_ = input_.copy()
17951800

17961801
# Tree training path
1797-
if self.enable_tree_training:
1802+
if self.tree_training_mode == "sparse":
17981803
mb_list = build_packed_tree_batch(
17991804
input_,
18001805
mb_spec=self.config.mb_spec,
@@ -2063,12 +2068,12 @@ def _compute_logprobs_and_loss(
20632068
if local_weight == 0:
20642069
return logits.mean() * 0.0
20652070

2066-
if self.config.is_critic and self.enable_tree_training:
2071+
if self.config.is_critic and self.tree_training_mode == "sparse":
20672072
raise NotImplementedError(
20682073
"Tree training with critic model is not supported yet."
20692074
)
20702075
if not self.config.is_critic:
2071-
if self.enable_tree_training:
2076+
if self.tree_training_mode == "sparse":
20722077
# Handle dummy trie (empty tree for DP synchronization)
20732078
# When trie has no sequences, return zero loss with grad connection
20742079
if ctx.trie_node is None or not ctx.trie_node.all_sequence_ids:
@@ -2126,12 +2131,12 @@ def _compute_forward_result(
21262131
ctx: FSDPTrainContext,
21272132
) -> torch.Tensor | dict[int, torch.Tensor]:
21282133
"""Compute forward output (logprobs or values)."""
2129-
if self.config.is_critic and self.enable_tree_training:
2134+
if self.config.is_critic and self.tree_training_mode == "sparse":
21302135
raise NotImplementedError(
21312136
"Tree training with critic model is not supported yet."
21322137
)
21332138
if not self.config.is_critic:
2134-
if self.enable_tree_training:
2139+
if self.tree_training_mode == "sparse":
21352140
result = _gather_packed_tree_logprobs(
21362141
logits,
21372142
ctx.trie_node,

areal/engine/megatron_engine.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,13 @@ def __init__(self, config: TrainEngineConfig):
197197
self.seed: int = 0
198198
self.own_global_group: bool = False
199199
self.is_offload: bool = False
200+
self.tree_training_mode: str = self.config.tree_training_mode
201+
if self.tree_training_mode == "dta":
202+
raise ValueError(
203+
"tree_training_mode='dta' is only supported by ArchonEngine. "
204+
"Please use Archon backend or set tree_training_mode to 'disabled'/'sparse'."
205+
)
200206
self._offload_depth: int = 0
201-
self.enable_tree_training: bool = self.config.enable_tree_training
202207
# FP8 configuration
203208
self.fp8_config = self.mcore_config.fp8_config
204209
self.enable_fp8: bool = self.fp8_config is not None
@@ -331,7 +336,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
331336
self.tokenizer = load_hf_tokenizer(self.config.path)
332337

333338
with patch_bridge_for_tree_training(
334-
self.enable_tree_training and self.bridge_cls == "mbridge"
339+
self.tree_training_mode == "sparse" and self.bridge_cls == "mbridge"
335340
):
336341
self.bridge = self._build_hf_mcore_bridge()
337342

@@ -530,7 +535,7 @@ def _build_hf_mcore_bridge(self):
530535
)
531536

532537
elif self.bridge_cls == "megatron-bridge":
533-
if self.enable_tree_training:
538+
if self.tree_training_mode == "sparse":
534539
raise NotImplementedError(
535540
"Tree training is not supported with bridge_type='megatron-bridge'."
536541
)
@@ -819,7 +824,7 @@ def forward_step(batch_iter, model):
819824
# save_for_backward() which can only save torch.Tensor objects;
820825
# BlockMask is recreated inside PytorchFlexAttention.forward().
821826
tree_attn_keys: list[str] = []
822-
if self.enable_tree_training:
827+
if self.tree_training_mode == "sparse":
823828
trie_node = mb_input.padded_mb.get("trie_node", None)
824829
# Ensure trie_node is also in orig_mb for _compute_logprobs_and_loss
825830
if trie_node is not None and "trie_node" not in mb_input.orig_mb:
@@ -1046,7 +1051,7 @@ def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None:
10461051
# Step 4: Aggregate, reorder, and broadcast outputs
10471052
res = None
10481053
if mpu.is_pipeline_last_stage():
1049-
if self.enable_tree_training:
1054+
if self.tree_training_mode == "sparse":
10501055
res = merge_packed_tree_results(outputs, batch_size)
10511056
else:
10521057
res = reorder_and_pad_outputs(
@@ -1926,7 +1931,7 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
19261931
pp_size = self.parallel_strategy.pipeline_parallel_size
19271932
cp_size = self.parallel_strategy.context_parallel_size
19281933
tp_size = self.parallel_strategy.tensor_parallel_size
1929-
if self.enable_tree_training:
1934+
if self.tree_training_mode == "sparse":
19301935
assert cp_size == 1, (
19311936
"Context parallelism is not supported in tree training."
19321937
)
@@ -2036,12 +2041,12 @@ def _compute_logprobs_and_loss(
20362041
if local_weight == 0:
20372042
return output.mean() * 0.0
20382043

2039-
if self.config.is_critic and self.enable_tree_training:
2044+
if self.config.is_critic and self.tree_training_mode == "sparse":
20402045
raise NotImplementedError(
20412046
"Tree training with critic model is not supported yet."
20422047
)
20432048
if not self.config.is_critic:
2044-
if self.enable_tree_training:
2049+
if self.tree_training_mode == "sparse":
20452050
# Handle dummy trie (empty tree for DP synchronization)
20462051
# When trie has no sequences, return zero loss with grad connection
20472052
trie_node = inputs.get("trie_node")
@@ -2144,12 +2149,12 @@ def _compute_forward_result(
21442149
output: torch.Tensor,
21452150
inputs: dict[str, Any],
21462151
) -> torch.Tensor | dict[int, torch.Tensor]:
2147-
if self.config.is_critic and self.enable_tree_training:
2152+
if self.config.is_critic and self.tree_training_mode == "sparse":
21482153
raise NotImplementedError(
21492154
"Tree training with critic model is not supported yet."
21502155
)
21512156
if not self.config.is_critic:
2152-
if self.enable_tree_training:
2157+
if self.tree_training_mode == "sparse":
21532158
logprobs = _gather_packed_tree_logprobs(
21542159
output,
21552160
inputs["trie_node"],

0 commit comments

Comments
 (0)