feat(archon): implement LoRA infrastructure with FSDP2/DTensor compatibility and PEFT checkpointing #1015
MikaStars39 wants to merge 1 commit into inclusionAI:main
Conversation
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the Archon engine by introducing a robust LoRA (Low-Rank Adaptation) infrastructure. It provides a parallel-aware implementation that integrates seamlessly with existing Tensor Parallelism and FSDP2 setups, crucially addressing a known deadlock. The changes also enable PEFT-compatible checkpointing for LoRA adapter weights, allowing for interoperability with the HuggingFace ecosystem. Additionally, it refines weight synchronization mechanisms and includes a critical bug fix related to gradient norm calculation in distributed environments.
Activity
Code Review
This pull request introduces a comprehensive LoRA (Low-Rank Adaptation) infrastructure compatible with FSDP2 and DTensor, which is a significant feature. The implementation correctly handles the complexities of distributed training, including a fix for a potential deadlock. It also adds PEFT-compatible checkpointing, which is great for interoperability. The code is generally well-structured, but there are a few areas for improvement regarding code style, clarity, and correctness in tests. Most notably, several new test files appear to be written for a different, outdated implementation of LoRALinear and will not work with the submitted code. This is a critical issue that needs to be addressed to ensure the new functionality is properly tested.
def test_freeze_non_lora_params_keeps_only_adapter_trainable():
    model = _ToyBlock()
    engine = _make_engine(model, ["wq"])

    engine._apply_lora()
    engine._freeze_non_lora_params()

    assert model.wq.weight.requires_grad is False
    assert model.wq.lora_a.weight.requires_grad is True
    assert model.wq.lora_b.weight.requires_grad is True
    assert model.other.weight.requires_grad is False
    assert model.inner.wv.weight.requires_grad is False
This test appears to be written for a different implementation of LoRALinear. It accesses model.wq.lora_a.weight, but the LoRALinear implementation in this PR does not have a lora_a submodule; it uses a plain tensor attribute _lora_a_weight. This test will fail and does not correctly validate the freezing logic for the submitted code. The test needs to be updated to access the LoRA weights correctly (e.g., model.wq._lora_a_weight).
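Applied to the assertions above, the minimal fix would look like this (a sketch; it assumes the plain-tensor attributes keep requires_grad=True after freezing):

assert model.wq.weight.requires_grad is False
assert model.wq._lora_a_weight.requires_grad is True
assert model.wq._lora_b_weight.requires_grad is True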
@@ -0,0 +1,528 @@
"""Unit tests for LoRALinear module and adapter utilities."""
The tests in this file appear to be written for a different implementation of LoRALinear. The current implementation in areal/experimental/models/archon/lora/lora_linear.py stores LoRA weights as plain tensor attributes (e.g., _lora_a_weight) to avoid FSDP hooks. However, these tests attempt to access them as if they were nn.Linear submodules (e.g., lora_linear.lora_a.weight).
This mismatch means the tests will fail and are not validating the submitted code. The tests need to be updated to reflect the actual LoRALinear implementation. For example, lora_linear.lora_b.weight should be lora_linear._lora_b_weight.
@dataclass
class LoRAConfig:
    enabled: bool
    rank: int
    alpha: float
    target_modules: list[str]

self.lora_config = LoRAConfig(
    enabled=True,
)

sync_lora_grads(
    self.model,

if self.lora_config is not None:
    from areal.experimental.models.archon.lora import LoRALinear
The method _get_all_parameters is type-hinted to return list[nn.Parameter], but it is being extended with torch.Tensor objects from module.lora_parameters(). This creates a type inconsistency. Please update the type hint to reflect the actual return type, for example list[torch.Tensor] or list[nn.Parameter | torch.Tensor].
if self.lora_config is not None:
    from areal.experimental.models.archon.lora import LoRALinear

def _get_all_parameters(self) -> list[torch.Tensor]:
    params: list[torch.Tensor] = [p for m in self.model_parts for p in m.parameters()]
if self._tp_enabled:
    result = self._tp_lora_forward(x, base_out)
    if result.requires_grad and hasattr(self, "_debug_name"):
        _name = self._debug_name
forward_only: bool,
) -> list[torch.Tensor | dict[int, torch.Tensor]]:
    results: list[torch.Tensor | dict[int, torch.Tensor]] = []
    total_mbs = len(mb_list)

-   for mb_item in mb_list:
+   for mb_idx, mb_item in enumerate(mb_list):

with stats_tracker.scope("update"):
    # Get current version for proximal approximation metrics
    current_version = self.engine.get_version()
    _n_mbs = len(mb_inputs.mbs)

-   for mb in mb_inputs.mbs:
+   for _mb_idx, mb in enumerate(mb_inputs.mbs):

# lora_b initialized to zeros
assert torch.allclose(
    lora_linear.lora_b.weight, torch.zeros_like(lora_linear.lora_b.weight)
Bug: Tests reference non-existent lora_a.weight / lora_b.weight attributes
LoRALinear stores LoRA weights as _lora_a_weight and _lora_b_weight via object.__setattr__. There are no lora_a or lora_b properties, sub-modules, or __getattr__ overrides in the class. Accessing lora_linear.lora_a will raise AttributeError through nn.Module.__getattr__.
This affects 40+ references across test_lora_linear.py and test_archon_engine_lora.py — all test files will fail at runtime.
# Test code expects:
lora_linear.lora_b.weight  # AttributeError: 'LoRALinear' has no attribute 'lora_b'

# Actual implementation stores:
object.__setattr__(self, "_lora_a_weight", _a)  # accessed as self._lora_a_weight

Suggestion: Either add @property accessors to LoRALinear that return proxy objects with a .weight attribute, or update all test references to use _lora_a_weight / _lora_b_weight directly.
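A rough illustration of the proxy-accessor option (a sketch only, not the PR's code; _WeightProxy and LoRALinearSketch are hypothetical names):

import torch
import torch.nn as nn

class _WeightProxy:
    # Minimal wrapper so test code can keep using `.weight`.
    def __init__(self, tensor: torch.Tensor):
        self.weight = tensor

class LoRALinearSketch(nn.Module):
    # Stores LoRA factors as plain tensors (as in this PR) but exposes
    # read-only lora_a / lora_b proxies for test compatibility.
    def __init__(self, in_features: int, out_features: int, rank: int):
        super().__init__()
        _a = torch.empty(rank, in_features, requires_grad=True)
        _b = torch.zeros(out_features, rank, requires_grad=True)
        # Plain-tensor storage bypasses nn.Module parameter registration.
        object.__setattr__(self, "_lora_a_weight", _a)
        object.__setattr__(self, "_lora_b_weight", _b)

    @property
    def lora_a(self) -> _WeightProxy:
        return _WeightProxy(self._lora_a_weight)

    @property
    def lora_b(self) -> _WeightProxy:
        return _WeightProxy(self._lora_b_weight)

With this, lora_linear.lora_b.weight resolves without touching FSDP's parameter machinery.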
if result.requires_grad and hasattr(self, "_debug_name"):
    _name = self._debug_name

    result.register_hook(lambda grad: grad)
Nit: No-op debug hook left in hot path
This identity lambda creates a new closure and registers an autograd hook on every forward pass when _tp_enabled and _debug_name exist:
if result.requires_grad and hasattr(self, "_debug_name"):
    _name = self._debug_name  # captured but never used
    result.register_hook(lambda grad: grad)  # identity hook, pure overhead

This adds unnecessary overhead to every backward pass. The captured _name variable is unused. This appears to be leftover debug code.
Suggestion: Remove the debug hook entirely, or gate it behind a debug flag (e.g., TORCH_DISTRIBUTED_DEBUG).
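One way to gate it (a sketch; the AREAL_LORA_DEBUG environment variable is hypothetical, not an existing flag):

import os
import torch

LORA_DEBUG = os.environ.get("AREAL_LORA_DEBUG", "0") == "1"

def maybe_register_debug_hook(result: torch.Tensor, name: str) -> None:
    # Attach a gradient-logging hook only when debugging is enabled,
    # so production backward passes pay no extra cost.
    if LORA_DEBUG and result.requires_grad:
        def _log(grad: torch.Tensor) -> torch.Tensor:
            print(f"[lora-debug] {name}: grad norm = {grad.norm().item():.4e}")
            return grad
        result.register_hook(_log)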
for m in self.model_parts:
    for module in m.modules():
        if isinstance(module, LoRALinear):
            params.extend(module.lora_parameters())
Type annotation mismatch: _get_all_parameters returns mixed types
The method signature at line 1083 declares -> list[nn.Parameter] but when LoRA is enabled, it extends with plain torch.Tensor objects from lora_parameters():
def _get_all_parameters(self) -> list[nn.Parameter]:  # annotation says nn.Parameter
    params = [p for m in self.model_parts for p in m.parameters()]
    if self.lora_config is not None:
        ...
        params.extend(module.lora_parameters())  # returns list[torch.Tensor], not nn.Parameter
    return params

Downstream consumers (fsdp2_clip_grad_norm, optimizer creation, gradient zeroing) may assume nn.Parameter semantics. Plain tensors behave differently for .grad accumulation and optimizer param group tracking.
Suggestion: Update the return type to list[nn.Parameter | torch.Tensor] and audit all call sites to ensure they handle both types correctly.
| """Return list of adapter parameter names relative to this module. | ||
|
|
||
| Returns: | ||
| List of parameter names (e.g., ["lora_a.weight", "lora_b.weight"]) |
Nit: Docstring example doesn't match actual return values
The AdapterModule protocol docstring says:
Returns:
    List of parameter names (e.g., ["lora_a.weight", "lora_b.weight"])

But the actual LoRALinear.adapter_params() implementation returns ["_lora_a_weight", "_lora_b_weight"] (underscore-prefixed, no .weight suffix). The get_adapter_params function uses these names with getattr(), so the implementation works correctly, but the protocol docstring is misleading and may confuse future implementors.
Suggestion: Update the docstring example to ["_lora_a_weight", "_lora_b_weight"] or note that the naming convention depends on the implementation.
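A corrected protocol docstring might read (suggested wording only):

def adapter_params(self) -> list[str]:
    """Return adapter parameter names relative to this module.

    Returns:
        Attribute names resolvable via getattr(), e.g.
        ["_lora_a_weight", "_lora_b_weight"] for LoRALinear. The exact
        naming convention is implementation-defined.
    """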
garrett4wade left a comment
Hi @MikaStars39, thanks for the great contribution! It saves us a great deal of work.
There are still several critical issues to be addressed, as listed in the comments. We need to be more rigorous about the forward/backward computation with a multi-layer model, and with parallelism enabled.
In addition, please rebase onto the latest main and squash the commit messages. There are some pre-existing changes in this PR caused by merging.
if hasattr(config, "use_lora") and config.use_lora:
    from dataclasses import dataclass

    @dataclass
    class LoRAConfig:
        enabled: bool
        rank: int
        alpha: float
        target_modules: list[str]

    self.lora_config = LoRAConfig(
        enabled=True,
        rank=config.lora_rank,
        alpha=float(config.lora_alpha),
        target_modules=config.target_modules if config.target_modules else [],
    )
Should place this config object in cli_args.py and totally refactor the use_lora field in TrainEngineConfig.
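A rough sketch of that refactor (the surrounding TrainEngineConfig fields are assumptions):

from dataclasses import dataclass, field

@dataclass
class LoRAConfig:
    enabled: bool = False
    rank: int = 8
    alpha: float = 16.0
    target_modules: list[str] = field(default_factory=list)

@dataclass
class TrainEngineConfig:
    # ... existing fields ...
    # Replaces the loose use_lora / lora_rank / lora_alpha fields.
    lora: LoRAConfig = field(default_factory=LoRAConfig)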
# Save weights (only rank 0)
if rank == 0:
    weights_path = os.path.join(path, "adapter_model.safetensors")
    save_file(peft_state, weights_path)
    logger.info(f"Saved {len(peft_state)} adapter tensors to {weights_path}")

# Determine target modules from actual adapter parameters
target_modules = set()
for key in adapter_params:
    parts = key.split(".")
    for i, part in enumerate(parts):
        is_lora = part in ("lora_a", "lora_b") or part.startswith("_lora_")
        if is_lora and i > 0:
            module_name = parts[i - 1]
            target_modules.add(module_name)
            break
IMO the current implementation only supports TP/CP/FSDP. This should be marked with a TODO.
results: list[torch.Tensor | dict[int, torch.Tensor]] = []
total_mbs = len(mb_list)

-for mb_item in mb_list:
+for mb_idx, mb_item in enumerate(mb_list):
    inputs, ctx = self.prepare_inputs_fn(mb_item)

    tree_attn_meta = None
    if ctx.trie_node is not None:
        padded_size = mb_item.padded_to_length
        assert padded_size is not None
        tree_attn_meta = TreeAttentionMeta.from_trie(
            ctx.trie_node, padded_size, inputs["input_ids"].device
        )
        # Tree attention uses tree_attn_meta instead of cu_seqlens;
        # create dummy cu_seqlens for model interface compatibility.
        seq_len = inputs["input_ids"].shape[-1]
        cu_seqlens = torch.tensor(
            [0, seq_len], dtype=torch.int32, device=inputs["input_ids"].device
        )
        max_seqlen = seq_len
    else:
        cu_seqlens = inputs["cu_seqlens"]
        max_seqlen = int(inputs["max_seqlen"])

    logits = self.model(
        inputs["input_ids"],
        inputs["position_ids"],
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        tree_attn_meta=tree_attn_meta,
    )
    logits = logits.squeeze(0)
    del tree_attn_meta

    result = process_output_fn(logits, ctx.to_dict())

    if result is not None:
        if forward_only:
            # Result can be a tensor or dict (for tree training)
            if isinstance(result, dict):
save_model_to_hf(engine, meta.path, engine.tokenizer, None)

if dist.get_rank() == 0:
    ready_path = os.path.join(meta.path, WEIGHT_UPDATE_READY_FILE)
    ready_tmp_path = ready_path + ".tmp"
    ready_timestamp = str(datetime.now().timestamp())
    with open(ready_tmp_path, "w") as f:
        f.write(ready_timestamp)
    os.replace(ready_tmp_path, ready_path)

    update_name = names.update_weights_from_disk(
        engine.config.experiment_name,
        engine.config.trial_name,
        engine.get_version(),
    )
    name_resolve.add(
-        update_name, str(datetime.now().timestamp()), keepalive_ttl=120
+        update_name, ready_timestamp, keepalive_ttl=600
    )
Should rebase from main
local_w = getattr(linear.weight, "_local_tensor", linear.weight)
_a = torch.empty(rank, linear.in_features, device=local_w.device, dtype=local_w.dtype)
_b = torch.empty(linear.out_features, rank, device=local_w.device, dtype=local_w.dtype)
_a.requires_grad_(True)
_b.requires_grad_(True)
object.__setattr__(lora_linear, "_lora_a_weight", _a)
object.__setattr__(lora_linear, "_lora_b_weight", _b)
Are init parameters all-reduced when TP is enabled?
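If they are not, one common remedy is to broadcast the freshly initialized factors from a single TP rank so every rank starts from identical values (a sketch; the tp_group handle is an assumption):

import torch
import torch.distributed as dist

def broadcast_lora_init(a: torch.Tensor, b: torch.Tensor, tp_group) -> None:
    # torch.empty yields different values on each rank, so align the
    # factors by broadcasting from the TP group's first rank.
    src = dist.get_global_rank(tp_group, 0)
    dist.broadcast(a, src=src, group=tp_group)
    dist.broadcast(b, src=src, group=tp_group)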
current_version = self.engine.get_version()
_n_mbs = len(mb_inputs.mbs)

-for mb in mb_inputs.mbs:
+for _mb_idx, mb in enumerate(mb_inputs.mbs):
    train_stat = self.engine.train_batch(
        mb,
        loss_fn=functools.partial(
# Wait for async checkpoint staging to complete before modifying parameters
self.saver.maybe_wait_for_staging()
class _ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.wq = nn.Linear(8, 8)
        self.other = nn.Linear(8, 8)
        self.inner = nn.Module()
        self.inner.wv = nn.Linear(8, 8)
We need a multi-layer, end-to-end forward/backward precision test for scientific rigor. You can refer to existing tests for the archon engine.
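A sketch of such a test (single process, no parallelism; RefLoRALinear is a self-contained stand-in for the PR's LoRALinear, and the merged-weight identity W' = W + (alpha/rank) * B @ A is the property being checked):

import copy
import torch
import torch.nn as nn

class RefLoRALinear(nn.Module):
    # Self-contained reference LoRA layer used only for this sketch.
    def __init__(self, base: nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.base, self.scale = base, alpha / rank
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.02)
        self.lora_b = nn.Parameter(torch.randn(base.out_features, rank) * 0.02)

    def forward(self, x):
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scale

def test_two_layer_forward_backward_matches_merged_weights():
    torch.manual_seed(0)
    l1, l2 = nn.Linear(16, 32), nn.Linear(32, 16)
    model = nn.Sequential(RefLoRALinear(l1, 4, 8.0), nn.ReLU(), RefLoRALinear(l2, 4, 8.0))

    # Dense reference with the LoRA deltas merged into the base weights.
    m1, m2 = copy.deepcopy(l1), copy.deepcopy(l2)
    with torch.no_grad():
        m1.weight += model[0].scale * (model[0].lora_b @ model[0].lora_a)
        m2.weight += model[2].scale * (model[2].lora_b @ model[2].lora_a)
    merged = nn.Sequential(m1, nn.ReLU(), m2)

    x = torch.randn(8, 16)
    out = model(x)
    torch.testing.assert_close(out, merged(x), rtol=1e-5, atol=1e-5)

    out.sum().backward()  # backward must reach the adapters in both layers
    assert model[0].lora_a.grad is not None
    assert model[2].lora_b.grad is not None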
# Try to import PEFT for compatibility tests
try:
    import peft  # noqa: F401

    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
We may want to add an end-to-end test to load a model from native peft/transformers API.
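Such a test could look roughly like this (a sketch; adapter_path points at a checkpoint saved by this PR, and Qwen/Qwen2-0.5B is only an example base model):

import pytest
import torch
from transformers import AutoModelForCausalLM

peft = pytest.importorskip("peft")  # skip when PEFT isn't installed

def test_adapter_loads_via_native_peft_api(adapter_path: str):
    # A checkpoint saved by the engine should round-trip through PeftModel.
    base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
    model = peft.PeftModel.from_pretrained(base, adapter_path)

    ids = torch.tensor([[1, 2, 3, 4]])
    with torch.no_grad():
        logits = model(ids).logits
    assert logits.shape[:2] == (1, 4)  # loaded adapter produces valid logits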
Hi @garrett4wade, thanks for the code review! I will address these issues and rebase/squash. I will also add a new training test to verify that training works correctly after these fixes.
nit: just a link to the issue - this PR fixes #1040
…frastructure with FSDP2-DTensor deadlock fix
This pull request has been automatically marked as stale because it has not had recent activity within the last 14 days. Please add a comment or push new commits to keep it active. Thank you for your contribution!

Summary
This PR introduces Phase 1 & 2 of the LoRA (Low-Rank Adaptation) infrastructure for the Archon engine. It provides a robust, parallel-aware implementation of LoRA that seamlessly integrates with Tensor Parallelism (TP) and FSDP2. Crucially, it resolves a known deadlock issue between FSDP2 Data Parallel (DP) reduce-scatter and DTensor TP operations during the backward pass. It also introduces HuggingFace PEFT-compatible checkpointing for adapter weights.
Key Features & Architectural Changes

FSDP2-Safe LoRA Implementation (lora_linear.py):

* Implemented a custom LoRALinear module.
* Deadlock Fix: LoRA weights ($A$ and $B$ matrices) are stored as plain tensors (via object.__setattr__) rather than nn.Parameter. This intentional design bypasses FSDP2's post_accumulate_grad_hook, preventing the DP reduce-scatter operations from interleaving with DTensor TP operations, which previously caused diamond deadlocks.
* Added sync_lora_grads to manually all-reduce LoRA weight gradients across both TP and DP groups before the optimizer step (a minimal sketch of this sync appears after the bug-fix list below).

Archon Engine Integration (archon_engine.py):

* Added dynamic LoRA application (_apply_lora) to target modules based on LoRAConfig.
* Implemented _freeze_non_lora_params to lock base model weights while keeping adapter parameters trainable.
* Integrated LoRA initialization with the existing parallelization pipeline, ensuring LoRA is injected after TP/CP so that tensor-parallel planning operates correctly on nn.Linear.

PEFT-Compatible Checkpointing (archon_lora_checkpoint.py & base.py):

* LoRA adapters are saved and loaded in HuggingFace's PEFT format (adapter_model.safetensors and adapter_config.json).
* Introduced module name mapping in Qwen2StateDictAdapter to automatically translate Archon-specific FQN paths (e.g., layers.0.attention.wq.lora_a) to HF PEFT paths (e.g., self_attn.q_proj.lora_A).
* Added stripping of adapter parameters from base HuggingFace checkpoints during the initial load to prevent missing-key errors.

Weight Sync & Reliability (archon_weight_sync.py & remote_inf_engine.py):

* Improved the reliability of cross-node weight synchronization by implementing an atomic swap (.tmp to final) for the .areal_weight_update_ready signal file.
* Updated the remote inference engine to prioritize checking the disk-based ready file over the legacy name-resolve key to prevent timeouts.

Bug Fixes:

* Gradient Norm (grad.py): Fixed a hanging issue during get_grad_norm_fp32 when grads_for_norm is empty on certain ranks by ensuring those ranks still participate in the all_reduce with a zero contribution.
* Removed stale training debug info for cleaner logging.
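For reference, a minimal sketch of the manual gradient sync described above (group handles and the duck-typed attribute access are assumptions; the correct reduce op over TP depends on how activations are sharded):

import torch
import torch.distributed as dist

def sync_lora_grads_sketch(model: torch.nn.Module, tp_group, dp_group) -> None:
    # The LoRA factors are plain tensors, so FSDP2 never reduces their
    # grads; reduce them by hand before optimizer.step().
    for module in model.modules():
        for name in ("_lora_a_weight", "_lora_b_weight"):
            t = getattr(module, name, None)
            if t is None or t.grad is None:
                continue
            dist.all_reduce(t.grad, op=dist.ReduceOp.SUM, group=tp_group)  # partial sums across TP
            dist.all_reduce(t.grad, op=dist.ReduceOp.AVG, group=dp_group)  # average across DP replicas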
Testing

* LoRALinear forward/backward passes, dropout behavior, and PEFT mathematical equivalence.