
feat(archon): implement LoRA infrastructure with FSDP2/DTensor compatibility and PEFT checkpointing#1015

Open
MikaStars39 wants to merge 1 commit into inclusionAI:main from MikaStars39:lora

Conversation

@MikaStars39

Summary
This PR introduces Phase 1 & 2 of the LoRA (Low-Rank Adaptation) infrastructure for the Archon engine. It provides a robust, parallel-aware implementation of LoRA that seamlessly integrates with Tensor Parallelism (TP) and FSDP2. Crucially, it resolves a known deadlock issue between FSDP2 Data Parallel (DP) reduce-scatter and DTensor TP operations during the backward pass. It also introduces HuggingFace PEFT-compatible checkpointing for adapter weights.

Key Features & Architectural Changes

  • FSDP2-Safe LoRA Implementation (lora_linear.py): implemented a custom LoRALinear module.

  • Deadlock Fix: LoRA weights ($A$ and $B$ matrices) are stored as plain tensors (via object.__setattr__) rather than nn.Parameter. This intentional design bypasses FSDP2's post_accumulate_grad_hook, preventing the DP reduce-scatter operations from interleaving with DTensor TP operations, which previously caused diamond deadlocks.

  • Added sync_lora_grads to manually all-reduce LoRA weight gradients across both TP and DP groups before the optimizer step.
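The plain-tensor trick can be illustrated with a minimal, self-contained sketch (`TinyLoRALinear` is hypothetical, not the PR's `LoRALinear`): the LoRA factors live in the module's `__dict__` via `object.__setattr__`, so `parameters()` never yields them, and FSDP2's sharding and post-accumulate gradient hooks never see them.

```python
import torch
import torch.nn as nn


class TinyLoRALinear(nn.Module):
    """Sketch only: LoRA factors stored as plain tensors, invisible to FSDP2."""

    def __init__(self, base: nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.base = base
        self.scaling = alpha / rank
        a = (torch.randn(rank, base.in_features) * 0.02).requires_grad_(True)
        b = torch.zeros(base.out_features, rank).requires_grad_(True)
        # Plain tensors, not nn.Parameter: bypasses nn.Module registration,
        # so no FSDP2 post_accumulate_grad_hook is ever attached to them.
        object.__setattr__(self, "_lora_a_weight", a)
        object.__setattr__(self, "_lora_b_weight", b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lora_out = (x @ self._lora_a_weight.t()) @ self._lora_b_weight.t()
        return self.base(x) + lora_out * self.scaling
```

The flip side of hiding the factors from FSDP2 is that nothing reduces their gradients automatically; per the PR description, `sync_lora_grads` would `dist.all_reduce` each factor's `.grad` across the TP and DP process groups before the optimizer step (omitted here, since it needs an initialized process group).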

  • Archon Engine Integration (archon_engine.py):

  • Added dynamic LoRA application (_apply_lora) to target modules based on LoRAConfig.

  • Implemented _freeze_non_lora_params to lock base model weights while keeping adapter parameters trainable.

  • Integrated LoRA initialization with the existing parallelization pipeline, ensuring LoRA is injected after TP/CP so that tensor-parallel planning operates correctly on nn.Linear.

  • PEFT-Compatible Checkpointing (archon_lora_checkpoint.py & base.py):

  • LoRA adapters are saved and loaded in HuggingFace's PEFT format (adapter_model.safetensors and adapter_config.json).

  • Introduced module name mapping in Qwen2StateDictAdapter to automatically translate Archon-specific FQN paths (e.g., layers.0.attention.wq.lora_a) to HF PEFT paths (e.g., self_attn.q_proj.lora_A).

  • Added stripping of adapter parameters from base HuggingFace checkpoints during the initial load to prevent missing key errors.
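The FQN translation described above amounts to a string-level key rewrite; a rough sketch (the mapping table here is illustrative — the real one lives in Qwen2StateDictAdapter's to_peft_module_map):

```python
# Illustrative Archon -> HF PEFT module-name map (attention projections only).
ARCHON_TO_PEFT_MODULES = {
    "attention.wq": "self_attn.q_proj",
    "attention.wk": "self_attn.k_proj",
    "attention.wv": "self_attn.v_proj",
    "attention.wo": "self_attn.o_proj",
}
# PEFT capitalizes the adapter factor names.
ARCHON_TO_PEFT_ADAPTERS = {"lora_a": "lora_A", "lora_b": "lora_B"}


def archon_key_to_peft(key: str) -> str:
    """Translate e.g. 'layers.0.attention.wq.lora_a'
    into 'layers.0.self_attn.q_proj.lora_A'."""
    for src, dst in {**ARCHON_TO_PEFT_MODULES, **ARCHON_TO_PEFT_ADAPTERS}.items():
        key = key.replace(src, dst)
    return key
```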

  • Weight Sync & Reliability (archon_weight_sync.py & remote_inf_engine.py):

  • Improved the reliability of cross-node weight synchronization by implementing an atomic swap (.tmp to final) for the .areal_weight_update_ready signal file.

  • Updated the remote inference engine to prioritize checking the disk-based ready file over the legacy name-resolve key to prevent timeouts.
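The atomic swap is the standard write-to-temp-then-`os.replace` idiom; a sketch (the filename matches the signal file named above, the helper name is hypothetical):

```python
import os
import tempfile  # used in the demo/test below
import time

WEIGHT_UPDATE_READY_FILE = ".areal_weight_update_ready"


def write_ready_file_atomically(dirpath: str) -> str:
    """Write the timestamp to '<file>.tmp', then rename it into place.

    rename() is atomic on POSIX, so a reader polling for the ready file
    either sees the old file, no file, or the complete new file — never a
    partially written one.
    """
    final = os.path.join(dirpath, WEIGHT_UPDATE_READY_FILE)
    tmp = final + ".tmp"
    with open(tmp, "w") as f:
        f.write(str(time.time()))
    os.replace(tmp, final)
    return final
```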

  • Bug Fixes:

  • Gradient Norm (grad.py): Fixed a hanging issue during get_grad_norm_fp32 when grads_for_norm is empty on certain ranks by ensuring they still participate in the all_reduce with a zero contribution.

  • Removed stale training debug info for cleaner logging.
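The shape of the get_grad_norm_fp32 fix can be sketched as follows: each rank computes a local squared-norm contribution, falling back to zero when it holds no gradients, so every rank still enters the collective (hypothetical helper; the distributed call is shown only as a comment since it needs an initialized process group):

```python
import torch


def local_grad_sq_norm(grads_for_norm: list) -> torch.Tensor:
    """Local contribution to the global gradient norm.

    Ranks with nothing to contribute return a zero tensor instead of
    skipping the collective — skipping it is what caused the hang.
    """
    if not grads_for_norm:
        return torch.zeros(1, dtype=torch.float32)
    return torch.stack(
        [g.float().pow(2).sum() for g in grads_for_norm]
    ).sum().reshape(1)


# Every rank then participates unconditionally, e.g.:
#   dist.all_reduce(total, op=dist.ReduceOp.SUM, group=grad_norm_group)
#   grad_norm = total.sqrt()
```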

Testing

  • Added comprehensive unit tests for LoRALinear forward/backward passes, dropout behavior, and PEFT mathematical equivalence.
  • Added tests for HuggingFace/Archon state dict key translation and round-trip conversion.
  • Verified TP/CP ordering with mocked parallelize functions.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Archon engine by introducing a robust LoRA (Low-Rank Adaptation) infrastructure. It provides a parallel-aware implementation that integrates seamlessly with existing Tensor Parallelism and FSDP2 setups, crucially addressing a known deadlock. The changes also enable PEFT-compatible checkpointing for LoRA adapter weights, allowing for interoperability with HuggingFace ecosystems. Additionally, it refines weight synchronization mechanisms and includes a critical bug fix related to gradient norm calculation in distributed environments.

Highlights

  • LoRA Infrastructure: Implemented a custom LoRALinear module and integrated LoRA application into the Archon engine, ensuring compatibility with Tensor Parallelism (TP) and FSDP2. LoRA is injected after TP/CP to maintain correct tensor-parallel planning.
  • Deadlock Fix: Resolved a deadlock issue between FSDP2 Data Parallel (DP) reduce-scatter and DTensor TP operations by storing LoRA weights as plain tensors (not nn.Parameter), bypassing FSDP2's post_accumulate_grad_hook. Manual gradient all-reduction for LoRA weights is performed via sync_lora_grads.
  • PEFT-Compatible Checkpointing: Introduced functionality to save and load LoRA adapters in HuggingFace's PEFT format (adapter_model.safetensors and adapter_config.json). This includes module name mapping for Archon-specific FQN paths to HF PEFT paths and stripping adapter parameters from base HF checkpoints during initial load.
  • Weight Synchronization Reliability: Improved cross-node weight synchronization by implementing an atomic swap for the .areal_weight_update_ready signal file and updated the remote inference engine to prioritize this disk-based ready file.
  • Gradient Norm Bug Fix: Fixed a hanging issue in get_grad_norm_fp32 by ensuring all ranks, even those without gradients, participate in the all_reduce operation with a zero contribution.


Changelog
  • areal/engine/fsdp_utils/grad.py
    • Modified get_grad_norm_fp32 to prevent hangs by ensuring all ranks participate in all_reduce even if they have no gradients.
  • areal/experimental/engine/archon_checkpoint.py
    • Added logic to strip LoRA adapter parameters from HuggingFace state dicts before loading base models.
    • Modified missing keys filtering to exclude LoRA adapter keys during model loading.
  • areal/experimental/engine/archon_engine.py
    • Added lora_config attribute to ArchonEngine and initialized it based on TrainEngineConfig.
    • Integrated _freeze_non_lora_params call into the initialize method when LoRA is enabled.
    • Added sync_lora_grads call after forward_backward_batch to manually all-reduce LoRA gradients.
    • Modified save method to conditionally use save_lora_adapter for LoRA-enabled models.
    • Modified load method to conditionally use load_lora_adapter for PEFT adapter checkpoints.
    • Passed _apply_lora function to parallelization methods to inject LoRA after TP/CP.
    • Implemented _apply_lora to dynamically replace nn.Linear modules with LoRALinear.
    • Implemented _freeze_non_lora_params to manage LoRA parameter trainability and initialization.
    • Updated _get_all_parameters to include LoRA parameters.
  • areal/experimental/engine/archon_lora_checkpoint.py
    • Added save_lora_adapter function to save LoRA adapters in PEFT format.
    • Added load_lora_adapter function to load LoRA adapters from PEFT format checkpoints.
    • Added is_lora_adapter_checkpoint function to detect PEFT LoRA adapter checkpoints.
  • areal/experimental/engine/archon_runner.py
    • Added total_mbs and mb_idx to the minibatch processing loop.
    • Removed a redundant comment regarding result types.
  • areal/experimental/engine/archon_weight_sync.py
    • Defined WEIGHT_UPDATE_READY_FILE constant for atomic weight update signaling.
    • Modified update_weights_from_disk to use save_lora_adapter for LoRA models and implemented an atomic file-based ready signal.
  • areal/experimental/models/archon/base.py
    • Added to_peft_module_map attribute to BaseStateDictAdapter for LoRA module name mapping.
    • Added create_peft_adapter_config method to generate PEFT-compatible adapter_config.json.
  • areal/experimental/models/archon/lora/init.py
    • Created the lora package and exposed its public API for LoRA modules and utilities.
  • areal/experimental/models/archon/lora/adapter.py
    • Defined AdapterModule protocol for modules containing adapter parameters.
    • Added get_adapter_params to extract adapter parameters from a model.
    • Added set_trainable_params to freeze/unfreeze model parameters.
    • Added get_adapter_state_dict to filter state dictionaries for adapter parameters.
    • Added disable_adapter and enable_adapter functions to control LoRA adapter activation.
  • areal/experimental/models/archon/lora/lora_linear.py
    • Implemented LoRALinear module, a custom linear layer for LoRA.
    • Designed LoRALinear to store LoRA weights as plain tensors to prevent FSDP2 deadlocks.
    • Included sync_lora_grads function for manual all-reduction of LoRA gradients across TP and DP groups.
  • areal/experimental/models/archon/qwen2/infra/parallelize.py
    • Imported Callable type for function annotations.
    • Added apply_lora_fn parameter to parallelize_qwen2.
    • Called apply_lora_fn after TP/CP to ensure correct LoRA injection order.
  • areal/experimental/models/archon/qwen3/infra/parallelize.py
    • Imported Callable type for function annotations.
    • Added apply_lora_fn parameter to parallelize_qwen3.
    • Called apply_lora_fn after TP/EP/CP for correct LoRA injection order.
  • areal/infra/remote_inf_engine.py
    • Added _wait_for_disk_weight_update_ready function to prioritize file-based weight update signals.
    • Updated _wait_for_disk_weight_update_ready call and increased its timeout.
  • areal/trainer/ppo/actor.py
    • Added _n_mbs and _mb_idx variables to the minibatch loop for internal tracking.
  • areal/trainer/rl_trainer.py
    • Removed a comment related to waiting for async checkpoint staging.
  • areal/utils/logging.py
    • Added 'LoRACheckpoint' to the LOG_COLORS dictionary for colored logging output.
  • tests/experimental/archon/test_archon_engine_lora.py
    • Added unit tests for ArchonEngine's LoRA integration, covering _apply_lora and _freeze_non_lora_params.
    • Included tests for update_weights_from_disk behavior with LoRA.
    • Verified the correct application order of LoRA within parallelization strategies.
  • tests/experimental/archon/test_archon_lora_checkpoint.py
    • Added unit tests for Qwen2StateDictAdapter's LoRA key conversions.
    • Included tests for PEFT adapter configuration generation.
    • Verified the functionality of is_lora_adapter_checkpoint for detecting PEFT checkpoints.
    • Tested state dict round-trip conversion with LoRA keys.
  • tests/experimental/archon/test_lora_linear.py
    • Added comprehensive unit tests for the LoRALinear module, including initialization, forward/backward passes, and dropout.
    • Tested from_linear conversion and AdapterModule protocol implementation.
    • Included tests for adapter utility functions like get_adapter_params and set_trainable_params.
    • Provided compatibility tests against HuggingFace PEFT's LoRA Linear module for forward pass, gradient flow, and scaling factor.

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a comprehensive LoRA (Low-Rank Adaptation) infrastructure compatible with FSDP2 and DTensor, which is a significant feature. The implementation correctly handles the complexities of distributed training, including a fix for a potential deadlock. It also adds PEFT-compatible checkpointing, which is great for interoperability. The code is generally well-structured, but there are a few areas for improvement regarding code style, clarity, and correctness in tests. Most notably, several new test files appear to be written for a different, outdated implementation of LoRALinear and will not work with the submitted code. This is a critical issue that needs to be addressed to ensure the new functionality is properly tested.

Comment on lines +69 to +80
def test_freeze_non_lora_params_keeps_only_adapter_trainable():
model = _ToyBlock()
engine = _make_engine(model, ["wq"])

engine._apply_lora()
engine._freeze_non_lora_params()

assert model.wq.weight.requires_grad is False
assert model.wq.lora_a.weight.requires_grad is True
assert model.wq.lora_b.weight.requires_grad is True
assert model.other.weight.requires_grad is False
assert model.inner.wv.weight.requires_grad is False
Contributor


critical

This test appears to be written for a different implementation of LoRALinear. It accesses model.wq.lora_a.weight, but the LoRALinear implementation in this PR does not have a lora_a submodule; it uses a plain tensor attribute _lora_a_weight. This test will fail and does not correctly validate the freezing logic for the submitted code. The test needs to be updated to access the LoRA weights correctly (e.g., model.wq._lora_a_weight).

@@ -0,0 +1,528 @@
"""Unit tests for LoRALinear module and adapter utilities."""
Contributor


critical

The tests in this file appear to be written for a different implementation of LoRALinear. The current implementation in areal/experimental/models/archon/lora/lora_linear.py stores LoRA weights as plain tensor attributes (e.g., _lora_a_weight) to avoid FSDP hooks. However, these tests attempt to access them as if they were nn.Linear submodules (e.g., lora_linear.lora_a.weight).

This mismatch means the tests will fail and are not validating the submitted code. The tests need to be updated to reflect the actual LoRALinear implementation. For example, lora_linear.lora_b.weight should be lora_linear._lora_b_weight.

Comment on lines +198 to +206
@dataclass
class LoRAConfig:
enabled: bool
rank: int
alpha: float
target_modules: list[str]

self.lora_config = LoRAConfig(
enabled=True,
Contributor


medium

Defining the LoRAConfig dataclass inside the __init__ method causes it to be redefined on every instantiation of ArchonEngine. It would be better for clarity, performance, and potential reuse to define it at the module level.

Comment on lines +513 to +515
)
sync_lora_grads(
self.model,
Contributor


medium

Local imports can obscure dependencies and are best avoided unless there's a specific reason like preventing circular imports. Consider moving this import to the top of the file for better code clarity and consistency.

Comment on lines +1085 to +1086
if self.lora_config is not None:
from areal.experimental.models.archon.lora import LoRALinear
Contributor


medium

The method _get_all_parameters is type-hinted to return list[nn.Parameter], but it is being extended with torch.Tensor objects from module.lora_parameters(). This creates a type inconsistency. Please update the type hint to reflect the actual return type, for example list[torch.Tensor] or list[nn.Parameter | torch.Tensor].

Suggested change
if self.lora_config is not None:
from areal.experimental.models.archon.lora import LoRALinear
def _get_all_parameters(self) -> list[torch.Tensor]:
params: list[torch.Tensor] = [p for m in self.model_parts for p in m.parameters()]

if self._tp_enabled:
result = self._tp_lora_forward(x, base_out)
if result.requires_grad and hasattr(self, "_debug_name"):
_name = self._debug_name
Contributor


medium

The variable _name is assigned but never used. This appears to be dead code and should be removed.

forward_only: bool,
) -> list[torch.Tensor | dict[int, torch.Tensor]]:
results: list[torch.Tensor | dict[int, torch.Tensor]] = []
total_mbs = len(mb_list)
Contributor


medium

The variable total_mbs is defined but never used. This is dead code and should be removed to improve clarity.

total_mbs = len(mb_list)

for mb_item in mb_list:
for mb_idx, mb_item in enumerate(mb_list):
Contributor


medium

The variable mb_idx is defined but never used. Consider using for mb_item in mb_list: instead.

Suggested change
for mb_idx, mb_item in enumerate(mb_list):
for mb_item in mb_list:

with stats_tracker.scope("update"):
# Get current version for proximal approximation metrics
current_version = self.engine.get_version()
_n_mbs = len(mb_inputs.mbs)
Contributor


medium

The variable _n_mbs is defined but never used. This is dead code and should be removed to improve clarity.

_n_mbs = len(mb_inputs.mbs)

for mb in mb_inputs.mbs:
for _mb_idx, mb in enumerate(mb_inputs.mbs):
Contributor


medium

The variable _mb_idx is defined but never used. Consider using for mb in mb_inputs.mbs: instead.

Suggested change
for _mb_idx, mb in enumerate(mb_inputs.mbs):
for mb in mb_inputs.mbs:

@MikaStars39
Author

Testing result on DAPO-math-17k and DeepSeek-R1-Distill-Qwen-1.5B:
[training-curve screenshot]

settings:

experiment_name: gsm8k-grpo
trial_name: trial0

seed: 1
enable_offload: false
total_train_epochs: 10
tokenizer_path: ${actor.path}

cluster:
  n_nodes: 1
  n_gpus_per_node: 8
  fileroot: /your_path/qingyu/PeRL/outputs
  name_resolve:
    type: nfs
    nfs_record_root: ${cluster.fileroot}/name_resolve

allocation_mode: sglang:d4+archon:d2t2

scheduler:
  type: null

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: 128
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 2
  enable_rollout_tracing: true
  scheduling_spec: ${actor.scheduling_spec}
  use_lora: true
  fileroot: ${cluster.fileroot}
  tokenizer_path: ${tokenizer_path}
  dump_to_file: true

gconfig:
  n_samples: 8
  min_new_tokens: 0
  max_new_tokens: 16384
  greedy: false
  temperature: 1.0
  lora_name: "lora"

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /your_path/qingyu/.cache/DeepSeek-R1-Distill-Qwen-1.5B
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer:
    type: adam
    lr: 2e-5
    weight_decay: 0.01
    beta1: 0.9
    beta2: 0.98
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  eps_clip: 0.2
  eps_clip_higher: 0.28
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 4
  recompute_logprob: true
  use_decoupled_loss: true
  behave_imp_weight_cap: 5.0
  reward_norm:
    mean_level: group
    std_level: group
    group_size: ${gconfig.n_samples}
  adv_norm:
    mean_level: batch
    std_level: batch
  max_new_tokens: ${gconfig.max_new_tokens}
  weight_update_mode: disk  # must be disk

  # lora
  use_lora: ${rollout.use_lora}
  peft_type: lora
  lora_rank: 32
  lora_alpha: 32
  target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
  scheduling_spec:
    - task_type: worker
      port_count: 2
      gpu: 1
      mem: 32
      cmd: python3 -m areal.infra.rpc.rpc_server
      env_vars: {}

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  disable_dropout: true
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer: null
  scheduling_strategy:
    type: colocation
    target: actor
  scheduling_spec: ${actor.scheduling_spec}

# SGLang
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.8
  # lora
  enable_lora: ${actor.use_lora}
  max_lora_rank: ${actor.lora_rank}

# datasets
train_dataset:
  batch_size: 64
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: /your_path/qingyu/.cache/dapo-math-17k
  type: rl

valid_dataset:
  batch_size: 64
  pin_memory: true
  num_workers: 4
  path: /your_path/qingyu/.cache/aime-2024
  type: rl

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: 32
  freq_secs: null

recover:
  mode: disabled
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: 32
  freq_secs: null

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: 32
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}

perf_tracer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  enabled: false
  session_tracer:
    enabled: false


# lora_b initialized to zeros
assert torch.allclose(
lora_linear.lora_b.weight, torch.zeros_like(lora_linear.lora_b.weight)
Collaborator


Bug: Tests reference non-existent lora_a.weight / lora_b.weight attributes

LoRALinear stores LoRA weights as _lora_a_weight and _lora_b_weight via object.__setattr__. There are no lora_a or lora_b properties, sub-modules, or __getattr__ overrides in the class. Accessing lora_linear.lora_a will raise AttributeError through nn.Module.__getattr__.

This affects 40+ references across test_lora_linear.py and test_archon_engine_lora.py; all affected test files will fail at runtime.

# Test code expects:
lora_linear.lora_b.weight  # AttributeError: 'LoRALinear' has no attribute 'lora_b'

# Actual implementation stores:
object.__setattr__(self, "_lora_a_weight", _a)  # accessed as self._lora_a_weight

Suggestion: Either add @property accessors to LoRALinear that return proxy objects with a .weight attribute, or update all test references to use _lora_a_weight / _lora_b_weight directly.

if result.requires_grad and hasattr(self, "_debug_name"):
_name = self._debug_name

result.register_hook(lambda grad: grad)
Collaborator


Nit: No-op debug hook left in hot path

This identity lambda creates a new closure and registers an autograd hook on every forward pass when _tp_enabled and _debug_name exist:

if result.requires_grad and hasattr(self, "_debug_name"):
    _name = self._debug_name  # captured but never used
    result.register_hook(lambda grad: grad)  # identity — adds overhead

This adds unnecessary overhead to every backward pass. The captured _name variable is unused. This appears to be leftover debug code.

Suggestion: Remove the debug hook entirely, or gate it behind a debug flag (e.g., TORCH_DISTRIBUTED_DEBUG).

for m in self.model_parts:
for module in m.modules():
if isinstance(module, LoRALinear):
params.extend(module.lora_parameters())
Collaborator


Type annotation mismatch: _get_all_parameters returns mixed types

The method signature at line 1083 declares -> list[nn.Parameter] but when LoRA is enabled, it extends with plain torch.Tensor objects from lora_parameters():

def _get_all_parameters(self) -> list[nn.Parameter]:  # annotation says nn.Parameter
    params = [p for m in self.model_parts for p in m.parameters()]
    if self.lora_config is not None:
        ...
        params.extend(module.lora_parameters())  # returns list[torch.Tensor], not nn.Parameter
    return params

Downstream consumers (fsdp2_clip_grad_norm, optimizer creation, gradient zeroing) may assume nn.Parameter semantics. Plain tensors behave differently for .grad accumulation and optimizer param group tracking.

Suggestion: Update the return type to list[nn.Parameter | torch.Tensor] and audit all call sites to ensure they handle both types correctly.

"""Return list of adapter parameter names relative to this module.

Returns:
List of parameter names (e.g., ["lora_a.weight", "lora_b.weight"])
Collaborator


Nit: Docstring example doesn't match actual return values

The AdapterModule protocol docstring says:

Returns:
    List of parameter names (e.g., ["lora_a.weight", "lora_b.weight"])

But the actual LoRALinear.adapter_params() implementation returns ["_lora_a_weight", "_lora_b_weight"] (underscore-prefixed, no .weight suffix). The get_adapter_params function uses these names with getattr(), so the implementation works correctly, but the protocol docstring is misleading and may confuse future implementors.

Suggestion: Update the docstring example to ["_lora_a_weight", "_lora_b_weight"] or note that the naming convention depends on the implementation.

Collaborator

@garrett4wade left a comment


Hi @MikaStars39 , thanks for the great contribution! It reduces many of our workloads.

There are still several critical issues to be addressed, as listed in the comments. We need to be more rigorous about the forward/backward computation with a multi-layer model, and with parallelism enabled.

In addition, please rebase onto the latest main and squash the commits. This PR currently includes some pre-existing changes pulled in by merging.

Comment on lines +195 to +210
if hasattr(config, "use_lora") and config.use_lora:
from dataclasses import dataclass

@dataclass
class LoRAConfig:
enabled: bool
rank: int
alpha: float
target_modules: list[str]

self.lora_config = LoRAConfig(
enabled=True,
rank=config.lora_rank,
alpha=float(config.lora_alpha),
target_modules=config.target_modules if config.target_modules else [],
)
Collaborator


Should place this config object in cli_args.py and totally refactor the use_lora field in TrainEngineConfig.

Comment on lines +92 to +108
# Save weights (only rank 0)
if rank == 0:
weights_path = os.path.join(path, "adapter_model.safetensors")
save_file(peft_state, weights_path)
logger.info(f"Saved {len(peft_state)} adapter tensors to {weights_path}")

# Determine target modules from actual adapter parameters
target_modules = set()
for key in adapter_params:
parts = key.split(".")
for i, part in enumerate(parts):
is_lora = part in ("lora_a", "lora_b") or part.startswith("_lora_")
if is_lora and i > 0:
module_name = parts[i - 1]
target_modules.add(module_name)
break

Collaborator


IMO the current implementation only supports TP/CP/FSDP. This should be marked with a TODO.

Comment on lines 74 to 112
results: list[torch.Tensor | dict[int, torch.Tensor]] = []
total_mbs = len(mb_list)

for mb_item in mb_list:
for mb_idx, mb_item in enumerate(mb_list):
inputs, ctx = self.prepare_inputs_fn(mb_item)

tree_attn_meta = None
if ctx.trie_node is not None:
padded_size = mb_item.padded_to_length
assert padded_size is not None
tree_attn_meta = TreeAttentionMeta.from_trie(
ctx.trie_node, padded_size, inputs["input_ids"].device
)
# Tree attention uses tree_attn_meta instead of cu_seqlens;
# create dummy cu_seqlens for model interface compatibility.
seq_len = inputs["input_ids"].shape[-1]
cu_seqlens = torch.tensor(
[0, seq_len], dtype=torch.int32, device=inputs["input_ids"].device
)
max_seqlen = seq_len
else:
cu_seqlens = inputs["cu_seqlens"]
max_seqlen = int(inputs["max_seqlen"])

logits = self.model(
inputs["input_ids"],
inputs["position_ids"],
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
tree_attn_meta=tree_attn_meta,
)
logits = logits.squeeze(0)
del tree_attn_meta

result = process_output_fn(logits, ctx.to_dict())

if result is not None:
if forward_only:
# Result can be a tensor or dict (for tree training)
if isinstance(result, dict):
Collaborator


should revert

Comment on lines +235 to 253
save_model_to_hf(engine, meta.path, engine.tokenizer, None)

if dist.get_rank() == 0:
    ready_path = os.path.join(meta.path, WEIGHT_UPDATE_READY_FILE)
    ready_tmp_path = ready_path + ".tmp"
    ready_timestamp = str(datetime.now().timestamp())
    with open(ready_tmp_path, "w") as f:
        f.write(ready_timestamp)
    os.replace(ready_tmp_path, ready_path)

    update_name = names.update_weights_from_disk(
        engine.config.experiment_name,
        engine.config.trial_name,
        engine.get_version(),
    )
    name_resolve.add(
-       update_name, str(datetime.now().timestamp()), keepalive_ttl=120
+       update_name, ready_timestamp, keepalive_ttl=600
    )
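The ready-file handshake above relies on the classic write-to-temp-then-rename idiom: `os.replace` is atomic on POSIX filesystems, so a reader of the ready file never observes a partially written timestamp. A standalone sketch of the pattern (paths and file names here are illustrative, not the PR's):

```python
import os
import tempfile
from datetime import datetime


def write_ready_file(ready_path: str) -> str:
    """Atomically publish a readiness timestamp.

    Writing to a sibling .tmp file and then renaming means readers either
    see the old file, no file, or the complete new file -- never a torn write.
    """
    ready_tmp_path = ready_path + ".tmp"
    ready_timestamp = str(datetime.now().timestamp())
    with open(ready_tmp_path, "w") as f:
        f.write(ready_timestamp)
    # os.replace is atomic when source and destination are on the same filesystem.
    os.replace(ready_tmp_path, ready_path)
    return ready_timestamp


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "update_weights.ready")
    ts = write_ready_file(path)
    with open(path) as f:
        assert f.read() == ts  # full timestamp visible, tmp file gone
```

Reusing the same `ready_timestamp` for both the file and the `name_resolve` entry (as the new line in the diff does) also keeps the two signals consistent.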

Collaborator
Should rebase from main

Comment on lines +232 to +239

local_w = getattr(linear.weight, "_local_tensor", linear.weight)
_a = torch.empty(rank, linear.in_features, device=local_w.device, dtype=local_w.dtype)
_b = torch.empty(linear.out_features, rank, device=local_w.device, dtype=local_w.dtype)
_a.requires_grad_(True)
_b.requires_grad_(True)
object.__setattr__(lora_linear, "_lora_a_weight", _a)
object.__setattr__(lora_linear, "_lora_b_weight", _b)
Collaborator
Are init parameters all-reduced when TP is enabled?
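(Not what this PR does, but one common way to sidestep the question: draw the replicated LoRA weights from an identically seeded generator on every rank, so the replicas start bit-identical and no all-reduce or broadcast of initial values is needed. A minimal single-process sketch; the helper and its seeding scheme are assumptions for illustration:)

```python
import torch


def init_lora_weights(in_features: int, out_features: int, rank: int, seed: int):
    """Deterministic LoRA init: identical seed on every TP rank yields
    bit-identical replicated A/B without any collective communication."""
    gen = torch.Generator().manual_seed(seed)
    # Conventional LoRA init: small random A, zero B, so the adapter
    # contributes exactly nothing at step 0.
    lora_a = torch.randn(rank, in_features, generator=gen) * (1.0 / rank)
    lora_b = torch.zeros(out_features, rank)
    return lora_a.requires_grad_(True), lora_b.requires_grad_(True)


a1, b1 = init_lora_weights(16, 32, rank=4, seed=1234)
a2, b2 = init_lora_weights(16, 32, rank=4, seed=1234)
assert torch.equal(a1, a2) and torch.equal(b1, b2)  # replicas agree
```

By contrast, `torch.empty` (as in the diff above) leaves the values uninitialized, so without a later seeded fill or a broadcast the TP ranks would indeed disagree.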

Comment on lines 321 to 327
current_version = self.engine.get_version()
_n_mbs = len(mb_inputs.mbs)

-for mb in mb_inputs.mbs:
+for _mb_idx, mb in enumerate(mb_inputs.mbs):
    train_stat = self.engine.train_batch(
        mb,
        loss_fn=functools.partial(
Collaborator
should revert.

Comment on lines 424 to 425

# Wait for async checkpoint staging to complete before modifying parameters
self.saver.maybe_wait_for_staging()
Collaborator
should revert

Comment on lines +14 to +23

class _ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.wq = nn.Linear(8, 8)
        self.other = nn.Linear(8, 8)
        self.inner = nn.Module()
        self.inner.wv = nn.Linear(8, 8)


Collaborator
We need a multi-layer, end-to-end forward/backward precision test for scientific rigor. You can refer to existing tests for the archon engine.
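A sketch of what such a precision test could look like, using a stand-in wrapper (`SketchLoRALinear` below is illustrative, not the PR's `LoRALinear`): with B initialized to zero, the LoRA-wrapped model must reproduce the base model's outputs exactly, and backward must populate gradients only on adapter weights.

```python
import torch
import torch.nn as nn


class SketchLoRALinear(nn.Module):
    """Minimal LoRA wrapper for test illustration only."""

    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # freeze base weights
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x):
        return self.base(x) + (x @ self.lora_a.T) @ self.lora_b.T * self.scaling


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
x = torch.randn(4, 8)
ref = model(x).detach()  # reference output before wrapping

lora_model = nn.Sequential(
    SketchLoRALinear(model[0]), model[1], SketchLoRALinear(model[2])
)
out = lora_model(x)
# B == 0 makes the adapter an exact no-op, so outputs must match bit-for-bit.
assert torch.allclose(out, ref, atol=1e-6)

out.sum().backward()
# Gradients flow only into adapter weights; frozen base weights get none.
assert all(p.grad is not None for n, p in lora_model.named_parameters() if "lora" in n)
assert all(p.grad is None for n, p in lora_model.named_parameters() if "lora" not in n)
```

A TP/FSDP-enabled variant of this test would additionally compare the sharded run's outputs and LoRA gradients against this single-device reference.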

Comment on lines +20 to +27
# Try to import PEFT for compatibility tests
try:
    import peft  # noqa: F401

    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
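A common alternative to the try/except probe above checks availability without importing the package at all, which avoids paying PEFT's import cost at test-collection time (a sketch, not part of this PR):

```python
import importlib.util


def module_available(name: str) -> bool:
    """True if `name` could be imported, without actually importing it."""
    return importlib.util.find_spec(name) is not None


PEFT_AVAILABLE = module_available("peft")
print(module_available("json"))  # stdlib module, always present → True
```

Either flag can then drive `pytest.mark.skipif(not PEFT_AVAILABLE, ...)` on the compatibility tests.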

Collaborator
We may want to add an end-to-end test to load a model from native peft/transformers API.
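For reference, a PEFT-compatible adapter checkpoint centers on an `adapter_config.json` next to the adapter weights; a minimal sketch of the relevant fields (a field subset assumed from PEFT's LoRA config, values illustrative):

```python
import json


def make_adapter_config(r: int, lora_alpha: int, target_modules) -> dict:
    """Build a minimal PEFT-style LoRA adapter config (illustrative subset)."""
    return {
        "peft_type": "LORA",
        "r": r,
        "lora_alpha": lora_alpha,
        "target_modules": sorted(target_modules),
    }


cfg = make_adapter_config(8, 16, {"q_proj", "v_proj"})
print(json.dumps(cfg, indent=2))

# The end-to-end load via the native PEFT API would then look like
# (requires peft + transformers; adapter_dir is hypothetical, not run here):
#   from peft import PeftModel
#   model = PeftModel.from_pretrained(base_model, adapter_dir)
```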

@MikaStars39
Author

> Hi @MikaStars39 , thanks for the great contribution! It reduces many of our workloads.
>
> There are still several critical issues to be addressed, as listed in the comments. We need to be more rigorous about the forward/backward computation with a multi-layer model, and with parallelism enabled.
>
> In addition, please rebase onto the latest main and squash the commit messages. There are some pre-existing changes caused by merging in this PR.

Hi @garrett4wade , thanks for the code review! I will address these issues and rebase/squash. I will also add a new training test to verify that training works correctly after these fixes.

@garrett4wade
Copy link
Copy Markdown
Collaborator

nit: just a link to the issue - this PR fixes #1040

zhennan0521 pushed a commit to zhennan0521/AReaL that referenced this pull request Mar 31, 2026
…frastructure with FSDP2-DTensor deadlock fix
@github-actions
github-actions Bot commented Apr 8, 2026

This pull request has been automatically marked as stale because it has not had recent activity within the last 14 days.

Please add a comment or push new commits to keep it active.

Thank you for your contribution!
