refactor(experimental): consolidate DTA Archon integration#1391
refactor(experimental): consolidate DTA Archon integration#1391ezoicoder wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces Dynamic Tree Attention (DTA) as a new tree training mode, replacing the boolean enable_tree_training flag with a multi-option tree_training_mode string. It adds the areal/experimental/dta module, integrates DTA into the Archon engine via a DTAWrapper, and updates attention mechanisms and Qwen2/Qwen3 models to support KV-cache attention. The review feedback highlights critical runtime AttributeError risks across the Qwen2/Qwen3 models and the DTA runner, where the code incorrectly assumes DynamicCache has a .layers attribute instead of using its standard key_cache and value_cache structures.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| past_len = 0 | ||
| if len(past_key_values.layers) > 0: | ||
| past_len = int(past_key_values.layers[0].keys.shape[2]) |
There was a problem hiding this comment.
The code assumes past_key_values has a .layers attribute. However, the standard transformers.cache_utils.DynamicCache class stores key/value states in key_cache and value_cache lists and does not have a .layers attribute. This will cause an AttributeError at runtime. Use get_seq_length() instead.
| past_len = 0 | |
| if len(past_key_values.layers) > 0: | |
| past_len = int(past_key_values.layers[0].keys.shape[2]) | |
| past_len = past_key_values.get_seq_length() |
| if past_key_values is not None and layer_idx < len(past_key_values.layers): | ||
| layer_entry = past_key_values.layers[layer_idx] | ||
| layer_past = (layer_entry.keys, layer_entry.values) |
There was a problem hiding this comment.
The code assumes past_key_values has a .layers attribute. Standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.
| if past_key_values is not None and layer_idx < len(past_key_values.layers): | |
| layer_entry = past_key_values.layers[layer_idx] | |
| layer_past = (layer_entry.keys, layer_entry.values) | |
| if past_key_values is not None and layer_idx < len(past_key_values): | |
| layer_past = (past_key_values.key_cache[layer_idx], past_key_values.value_cache[layer_idx]) |
| past_len = 0 | ||
| if len(past_key_values.layers) > 0: | ||
| past_len = int(past_key_values.layers[0].keys.shape[2]) |
There was a problem hiding this comment.
The code assumes past_key_values has a .layers attribute. However, the standard transformers.cache_utils.DynamicCache class stores key/value states in key_cache and value_cache lists and does not have a .layers attribute. This will cause an AttributeError at runtime. Use get_seq_length() instead.
| past_len = 0 | |
| if len(past_key_values.layers) > 0: | |
| past_len = int(past_key_values.layers[0].keys.shape[2]) | |
| past_len = past_key_values.get_seq_length() |
| if past_key_values is not None and layer_idx < len(past_key_values.layers): | ||
| layer_entry = past_key_values.layers[layer_idx] | ||
| layer_past = (layer_entry.keys, layer_entry.values) |
There was a problem hiding this comment.
The code assumes past_key_values has a .layers attribute. Standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.
| if past_key_values is not None and layer_idx < len(past_key_values.layers): | |
| layer_entry = past_key_values.layers[layer_idx] | |
| layer_past = (layer_entry.keys, layer_entry.values) | |
| if past_key_values is not None and layer_idx < len(past_key_values): | |
| layer_past = (past_key_values.key_cache[layer_idx], past_key_values.value_cache[layer_idx]) |
| new_cache = out.past_key_values | ||
| for layer_idx, layer in enumerate(new_cache.layers): | ||
| self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[ | ||
| :, :, start:end, : | ||
| ] | ||
| self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[ | ||
| :, :, start:end, : | ||
| ] |
There was a problem hiding this comment.
The code assumes out.past_key_values has a .layers attribute. However, standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.
| new_cache = out.past_key_values | |
| for layer_idx, layer in enumerate(new_cache.layers): | |
| self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[ | |
| :, :, start:end, : | |
| ] | |
| self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[ | |
| :, :, start:end, : | |
| ] | |
| new_cache = out.past_key_values | |
| for layer_idx in range(len(new_cache)): | |
| self.kv_cache[0][layer_idx][:, :, start:end, :] = new_cache.key_cache[layer_idx][ | |
| :, :, start:end, : | |
| ] | |
| self.kv_cache[1][layer_idx][:, :, start:end, :] = new_cache.value_cache[layer_idx][ | |
| :, :, start:end, : | |
| ] |
| for layer_idx, layer in enumerate(block_cache.layers): | ||
| k = layer.keys[:, :, start:end, :] | ||
| v = layer.values[:, :, start:end, :] | ||
| roots.extend([k, v]) | ||
| grads.extend( | ||
| [ | ||
| self.grad_kv[0][layer_idx][:, :, start:end, :], | ||
| self.grad_kv[1][layer_idx][:, :, start:end, :], | ||
| ] | ||
| ) |
There was a problem hiding this comment.
The code assumes block_cache has a .layers attribute. However, standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.
| for layer_idx, layer in enumerate(block_cache.layers): | |
| k = layer.keys[:, :, start:end, :] | |
| v = layer.values[:, :, start:end, :] | |
| roots.extend([k, v]) | |
| grads.extend( | |
| [ | |
| self.grad_kv[0][layer_idx][:, :, start:end, :], | |
| self.grad_kv[1][layer_idx][:, :, start:end, :], | |
| ] | |
| ) | |
| for layer_idx in range(len(block_cache)): | |
| k = block_cache.key_cache[layer_idx][:, :, start:end, :] | |
| v = block_cache.value_cache[layer_idx][:, :, start:end, :] | |
| roots.extend([k, v]) | |
| grads.extend( | |
| [ | |
| self.grad_kv[0][layer_idx][:, :, start:end, :], | |
| self.grad_kv[1][layer_idx][:, :, start:end, :], | |
| ] | |
| ) |
e0c0e7e to
5fdea9c
Compare
Integrate the Dynamic Tree Attention training path with Archon DP while keeping unsupported engines explicit. Key changes: - Add DTA trie, runner, allocation, rollout, and Zero1 wrapper utilities - Route Archon tree_training_mode='dta' through DTA-specific batch handling - Add DTA examples, docs, and distributed engine-step coverage
5fdea9c to
72acd4b
Compare
Description
Consolidates Dynamic Token Alignment (DTA) support into the experimental Archon
path. This adds DTA allocation, trie construction, runner/wrapper integration,
rollout preparation, examples, documentation, and focused regression coverage.
This update also fixes DTA microbatch construction so one-sequence-per-microbatch
batches stay per-rank independent instead of being forced through cross-rank
microbatch-count synchronization. The torchrun DTA case uses 17 turns so two
ranks receive uneven local sequence counts.
The latest update adds an end-to-end global loss comparison for the DTA
engine-step test. The torchrun runner now all-reduces each rank's returned local
loss contribution into
stats["global_loss"], and the pytest comparison checksbaseline Archon DP against DTA.
Related Issue
N/A
Type of Change
Checklist
pre-commit run --all-files)main/review-prcommand/create-prBreaking Change Details (if applicable):
N/A
Additional Context
Validation run:
Targeted torchrun result:
Skipped full repository test and docs build suites in this pass due scope and
runtime; the targeted DTA engine-step regression was run on the multi-GPU path.