From 4f7686dfe2f7d440db85870840315f24dfa5610a Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 13 Apr 2026 18:03:20 +0800 Subject: [PATCH 001/140] feat: speculative decode init --- areal/api/cli_args.py | 112 +++++ areal/api/io_struct.py | 6 + areal/engine/megatron_engine.py | 269 +++++++++++- areal/engine/megatron_utils/megatron.py | 60 +++ .../megatron_utils/packed_context_parallel.py | 16 + areal/engine/sglang_remote.py | 17 + areal/infra/remote_inf_engine.py | 10 + areal/models/mcore/registry.py | 17 + areal/trainer/ppo/actor.py | 9 + areal/trainer/rl_trainer.py | 8 + areal/workflow/rlvr.py | 32 ++ docs/en/_toc.yml | 1 + docs/en/cli_reference.md | 131 +++--- docs/en/tutorial/speculative_decoding.md | 273 ++++++++++++ docs/zh/_toc.yml | 1 + docs/zh/cli_reference.md | 131 +++--- docs/zh/tutorial/speculative_decoding.md | 262 +++++++++++ examples/math/gsm8k_grpo_megatron_eagle.yaml | 200 +++++++++ tests/speculative_decoding/__init__.py | 0 .../config_spec_only.yaml | 170 ++++++++ .../config_spec_with_mtp.yaml | 177 ++++++++ tests/speculative_decoding/entrypoint.py | 253 +++++++++++ .../test_speculative_decoding.py | 411 ++++++++++++++++++ 23 files changed, 2450 insertions(+), 116 deletions(-) create mode 100644 docs/en/tutorial/speculative_decoding.md create mode 100644 docs/zh/tutorial/speculative_decoding.md create mode 100644 examples/math/gsm8k_grpo_megatron_eagle.yaml create mode 100644 tests/speculative_decoding/__init__.py create mode 100644 tests/speculative_decoding/config_spec_only.yaml create mode 100644 tests/speculative_decoding/config_spec_with_mtp.yaml create mode 100644 tests/speculative_decoding/entrypoint.py create mode 100644 tests/speculative_decoding/test_speculative_decoding.py diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index e08c852ec4..aa7d7042f5 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -898,6 +898,29 @@ class MegatronEngineConfig: }, ) + # MTP (Multi-Token Prediction) Configuration + mtp_num_layers: int = field( + default=0, + metadata={ + "help": "Number of MTP (Multi-Token Prediction) layers for speculative decoding training. " + "0 means MTP is disabled." + }, + ) + mtp_loss_scaling_factor: float = field( + default=0.1, + metadata={ + "help": "Scaling factor for MTP auxiliary loss. Controls the weight of MTP loss " + "relative to the main RL loss." + }, + ) + mtp_detach_heads: bool = field( + default=True, + metadata={ + "help": "Whether to detach hidden states before passing to MTP heads in MegatronEngine. " + "When True, MTP loss gradients only update MTP parameters." + }, + ) + class SchedulingStrategyType(str, Enum): separation = "separation" @@ -1316,6 +1339,37 @@ class PPOActorConfig(TrainEngineConfig): metadata={"help": "Maximum number of new tokens to generate"}, ) + # MTP (Multi-Token Prediction) Online Training + enable_mtp_training: bool = field( + default=False, + metadata={ + "help": "Enable MTP (Multi-Token Prediction) online training during RL. " + "When enabled, MTP layers are trained alongside the main policy model " + "to keep the draft model aligned with the evolving policy." + }, + ) + mtp_num_layers: int = field( + default=1, + metadata={ + "help": "Number of MTP layers to train. Must match the model's MTP architecture." + }, + ) + mtp_loss_scaling_factor: float = field( + default=0.1, + metadata={ + "help": "Scaling factor for MTP auxiliary loss relative to the main RL loss." + }, + ) + mtp_detach_heads: bool = field( + default=True, + metadata={ + "help": "Whether to detach hidden states before passing to MTP heads. " + "When True (recommended for RL), MTP loss gradients only update MTP parameters, " + "preventing the MTP auxiliary loss from corrupting the main policy gradients. " + "When False, MTP loss gradients also flow back to the main model." + }, + ) + def should_compute_prox_logp(self) -> bool: """Determine if forward pass is needed for proximal log-probabilities. @@ -1373,6 +1427,19 @@ def __post_init__(self): "Please set `actor.use_decoupled_loss=false` in your configuration." ) + # Validate MTP configuration + if self.enable_mtp_training: + if self.mtp_num_layers <= 0: + raise ValueError( + f"mtp_num_layers must be > 0 when enable_mtp_training is True, " + f"got {self.mtp_num_layers}." + ) + if not (0 < self.mtp_loss_scaling_factor <= 1.0): + raise ValueError( + f"mtp_loss_scaling_factor must be in (0, 1.0], " + f"got {self.mtp_loss_scaling_factor}." + ) + super().__post_init__() @@ -1579,6 +1646,44 @@ class SGLangConfig: # Internal field, not exposed to users. enable_return_routed_experts: bool = False + # Speculative Decoding Configuration + speculative_algorithm: str | None = field( + default=None, + metadata={ + "help": "Speculative decoding algorithm. Options: 'EAGLE', 'EAGLE3'. None disables speculative decoding." + }, + ) + speculative_draft_model_path: str | None = field( + default=None, + metadata={"help": "Path to the draft model for speculative decoding."}, + ) + speculative_num_steps: int = field( + default=3, + metadata={"help": "Number of speculative decoding draft steps."}, + ) + speculative_eagle_topk: int = field( + default=1, + metadata={"help": "Top-k value for EAGLE draft token selection."}, + ) + speculative_num_draft_tokens: int = field( + default=4, + metadata={"help": "Number of draft tokens per speculative step."}, + ) + speculative_attention_mode: str | None = field( + default=None, + metadata={ + "help": "Attention mode for speculative decoding. E.g., 'full', 'sparse'." + }, + ) + enable_multi_layer_eagle: bool = False + enable_draft_weights_cpu_backup: bool | None = field( + default=None, + metadata={ + "help": "Keep draft model weights on CPU as backup during GPU offload cycles. " + "Essential for colocated training+inference mode to prevent draft weight loss." + }, + ) + # Use staticmethod to make OmegaConf happy. @staticmethod def build_cmd( @@ -1630,6 +1735,13 @@ def build_args( ) args.pop("enable_multithread_load", None) + # enable_draft_weights_cpu_backup: pass to SGLang ServerArgs constructor if set. + # Essential for colocated training+inference mode to prevent draft weight loss + # during GPU offload cycles. If None, let SGLang use its default. + draft_cpu_backup = args.pop("enable_draft_weights_cpu_backup", None) + if draft_cpu_backup is not None: + args["enable_draft_weights_cpu_backup"] = draft_cpu_backup + args = dict( # Model and tokenizer tokenizer_path=sglang_config.model_path, diff --git a/areal/api/io_struct.py b/areal/api/io_struct.py index e63f849230..68ae9b9c38 100644 --- a/areal/api/io_struct.py +++ b/areal/api/io_struct.py @@ -80,6 +80,10 @@ class ModelResponse: # MoE routing (only populated when return_routed_experts=True) routed_experts: np.ndarray | None = None + # Speculative decoding statistics + spec_accept_token_num: int = 0 + spec_draft_token_num: int = 0 + @property def input_len(self) -> int: return len(self.input_tokens) @@ -283,6 +287,8 @@ class HttpGenerationResult: output_logprobs: list[float] stop_reason: str routed_experts: np.ndarray | None = None + spec_accept_token_num: int | None = None + spec_draft_token_num: int | None = None @dataclass diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index eb486d8530..80e00dd998 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -173,6 +173,39 @@ def __init__(self, config: TrainEngineConfig): self.bridge_cls: str = getattr(self.mcore_config, "bridge_type", "mbridge") self.bridge_lora: MegatronBridgeLoRA | None = None + # MTP (Multi-Token Prediction) configuration + self.enable_mtp_training: bool = getattr( + self.config, "enable_mtp_training", False + ) + self.mtp_num_layers: int = getattr(self.config, "mtp_num_layers", 0) + self.mtp_loss_scaling_factor: float = getattr( + self.config, "mtp_loss_scaling_factor", 0.1 + ) + self.mtp_detach_heads: bool = getattr(self.config, "mtp_detach_heads", True) + self._mtp_loss_value: float = 0.0 + self._mtp_layers_verified: bool = False + if self.enable_mtp_training: + self.logger.info( + f"[MTPTrain] MTP online training ENABLED: " + f"num_layers={self.mtp_num_layers}, " + f"loss_scaling_factor={self.mtp_loss_scaling_factor}, " + f"detach_heads={self.mtp_detach_heads}" + ) + try: + import megatron.core.transformer.multi_token_prediction # noqa: F401 + + self.logger.info( + "[MTPTrain] Verified megatron-core MTP module available. " + "Gradient isolation (embedding detach + functional_call lm_head) " + "is handled internally by megatron-core MultiTokenPrediction module." + ) + except ImportError: + self.logger.error( + "[MTPTrain] megatron-core MTP module not found! " + "MTP training requires megatron-core >= 0.12.0. " + "Gradient isolation will NOT be applied, which corrupts RL training." + ) + def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): if parallel_strategy is None: parallel_strategy = ParallelStrategy() @@ -307,6 +340,19 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self._check_and_apply_fp8_config() self._validate_fp8_consistency() + # Propagate MTP config to mcore_config for model creation + if self.enable_mtp_training: + self.mcore_config.mtp_num_layers = self.mtp_num_layers + self.mcore_config.mtp_loss_scaling_factor = self.mtp_loss_scaling_factor + if hasattr(self.mcore_config, "mtp_detach_heads"): + self.mcore_config.mtp_detach_heads = self.mtp_detach_heads + self.logger.info( + f"[MTPTrain] Propagated MTP config to mcore_config: " + f"mtp_num_layers={self.mtp_num_layers}, " + f"mtp_loss_scaling_factor={self.mtp_loss_scaling_factor}, " + f"mtp_detach_heads={self.mtp_detach_heads}" + ) + with self.device: models = make_mcore_model( hf_config=self.hf_config, @@ -398,6 +444,28 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): model_config.param_sync_func = model_config.param_sync_func[0] model_config.finalize_model_grads_func = finalize_model_grads self._create_optimizer(ft_spec) + + if self.enable_mtp_training and not self._mtp_layers_verified: + mtp_param_count = 0 + for module in modules: + for name, param in module.named_parameters(): + if ".mtp." in name: + mtp_param_count += param.numel() + if mtp_param_count == 0: + self.logger.error( + "[MTPTrain] enable_mtp_training=True but NO MTP parameters found in model! " + "Possible causes: 1) mtp_num_layers=0 in model config; " + "2) Model checkpoint does not contain MTP layers; " + "3) mcore_config.mtp_num_layers not set correctly. " + "MTP loss will NOT be computed." + ) + else: + self._mtp_layers_verified = True + self.logger.info( + f"[MTPTrain] Verified MTP parameters in model: " + f"total_mtp_params={mtp_param_count / 1e6:.2f}M" + ) + self._initialized = True def _build_hf_mcore_bridge(self): @@ -635,6 +703,114 @@ def optimizer_zero_grad(self): for model in self.model: model.zero_grad_buffer() + @staticmethod + def _roll_tensor_packed( + tensor: torch.Tensor, shift: int, cu_seqlens: torch.Tensor + ) -> torch.Tensor: + """Roll tensor within each packed sequence boundary. + + In sequence packing mode, multiple sequences are concatenated. A naive + torch.roll would leak tokens across sequence boundaries. This function + rolls within each sequence independently and zeros out boundary positions. + """ + result = torch.zeros_like(tensor) + num_seqs = cu_seqlens.shape[0] - 1 + for i in range(num_seqs): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + seq_slice = tensor[..., start:end] + rolled = torch.roll(seq_slice, shifts=shift, dims=-1) + if shift < 0: + rolled[..., shift:] = 0 # zero out wrapped-around positions at end + else: + rolled[..., :shift] = 0 + result[..., start:end] = rolled + return result + + def _collect_mtp_loss(self) -> dict[str, float]: + """Collect MTP loss from Megatron-Core's MTPLossLoggingHelper after forward-backward. + + The MTP loss is computed and backpropagated by Megatron-Core's MTP module + during the forward-backward pass via MTPLossAutoScaler. This function only + collects the loss VALUE for logging and monitoring purposes. + + IMPORTANT: All CP ranks must participate in the all-reduce to avoid deadlock. + The gate condition uses is_pipeline_last_stage() instead of + is_mp_src_rank_with_outputs() to ensure all CP ranks enter the all-reduce. + """ + mtp_stats = {} + try: + from megatron.core.transformer.multi_token_prediction import ( + MTPLossLoggingHelper, + ) + + tracker = MTPLossLoggingHelper.tracker + if tracker and "values" in tracker: + values = tracker["values"] + + is_last_pp_stage = mpu.is_pipeline_last_stage(ignore_virtual=True) + + if tracker.get("reduce_group") is not None: + import torch.distributed + + torch.distributed.all_reduce(values, group=tracker["reduce_group"]) + if tracker.get("avg_group") is not None: + import torch.distributed + + torch.distributed.all_reduce( + values, + group=tracker["avg_group"], + op=torch.distributed.ReduceOp.AVG, + ) + + mtp_loss_value = values.item() + self._mtp_loss_value = mtp_loss_value + + if is_last_pp_stage: + mtp_stats["mtp_loss"] = mtp_loss_value + + if math.isnan(mtp_loss_value) or math.isinf(mtp_loss_value): + self.logger.error( + f"[MTPTrain] MTP loss is NaN/Inf! value={mtp_loss_value}. " + f"Check MTP label construction and model configuration." + ) + elif mtp_loss_value < 0 or mtp_loss_value > 100: + self.logger.warning( + f"[MTPTrain] MTP loss {mtp_loss_value:.6f} outside expected range [0, 100]." + ) + else: + self.logger.info( + f"[MTPTrain] MTP loss={mtp_loss_value:.6f}, " + f"scaling_factor={self.mtp_loss_scaling_factor}, " + f"scaled_mtp_loss={mtp_loss_value * self.mtp_loss_scaling_factor:.6f}, " + f"is_last_pp_stage={is_last_pp_stage}" + ) + + MTPLossLoggingHelper.clean_loss_in_tracker() + else: + if self.enable_mtp_training: + self.logger.warning( + "[MTPTrain] MTP loss tracker is empty after forward-backward " + "even though enable_mtp_training=True. Possible causes: " + "1) Model does not have MTP layers; " + "2) mtp_kwargs were not passed correctly; " + "3) Megatron-Core version mismatch. " + "Verify model architecture and mtp_num_layers config." + ) + + except ImportError: + self.logger.warning( + "[MTPTrain] Cannot import MTPLossLoggingHelper from megatron.core. " + "MTP loss collection disabled. Ensure megatron-core >= 0.12.0 " + "for MTP with gradient isolation support." + ) + except Exception as e: + self.logger.error( + f"[MTPTrain] Error collecting MTP loss: {e}", exc_info=True + ) + + return mtp_stats + def optimizer_step(self): with trace_scope("megatron_engine.step"): update_successful, grad_norm, _ = self.optimizer.step() @@ -687,7 +863,56 @@ def forward_step(batch_iter, model): mb_input.padded_mb.update(tree_kwargs) tree_attn_keys = list(tree_kwargs.keys()) - output = packed_context_parallel_forward(model, mb_input.padded_mb) + # Build MTP kwargs if MTP training is enabled + extra_block_kwargs = None + if self.enable_mtp_training: + mtp_labels = mb_input.padded_mb["input_ids"] + + loss_mask = mb_input.padded_mb.get("loss_mask", None) + mtp_loss_mask = None + if loss_mask is not None: + cu_seqlens = mb_input.padded_mb.get("cu_seqlens", None) + if cu_seqlens is not None: + mask_1 = self._roll_tensor_packed( + loss_mask, shift=-1, cu_seqlens=cu_seqlens + ) + mask_2 = self._roll_tensor_packed( + mask_1, shift=-1, cu_seqlens=cu_seqlens + ) + else: + mask_1 = torch.roll(loss_mask, shifts=-1, dims=-1) + mask_1[..., -1] = 0 + mask_2 = torch.roll(mask_1, shifts=-1, dims=-1) + mask_2[..., -1] = 0 + mtp_loss_mask = mask_1 * mask_2 + valid_mtp_tokens = mtp_loss_mask.sum().item() + total_mtp_tokens = mtp_loss_mask.numel() + self.logger.info( + f"[MTPTrain] MTP loss mask: valid_tokens={valid_mtp_tokens}, " + f"total_tokens={total_mtp_tokens}, " + f"mask_ratio={valid_mtp_tokens / max(total_mtp_tokens, 1):.4f}" + ) + else: + self.logger.warning( + "[MTPTrain] loss_mask is None; MTP loss will be computed over " + "all positions including padding. This may lead to incorrect " + "MTP loss values. Ensure loss_mask is provided in the input." + ) + + mtp_kwargs = {"mtp_labels": mtp_labels} + if mtp_loss_mask is not None: + mtp_kwargs["mtp_loss_mask"] = mtp_loss_mask + extra_block_kwargs = {"mtp_kwargs": mtp_kwargs} + + self.logger.info( + f"[MTPTrain] Forward step: mtp_labels shape={mtp_labels.shape}, " + f"dtype={mtp_labels.dtype}, " + f"has_mtp_loss_mask={mtp_loss_mask is not None}, " + f"mtp_num_layers={self.mtp_num_layers}" + ) + output = packed_context_parallel_forward( + model, mb_input.padded_mb, extra_block_kwargs=extra_block_kwargs + ) # Release tree attention metadata after forward pass for key in tree_attn_keys: @@ -763,8 +988,19 @@ def process_output( self.forward_backward_batch(mb_list, process_output, forward_only=False) - # Step 4: Optimizer step - return self.optimizer_step() + # Step 4: Collect MTP loss after forward-backward + mtp_loss_stats = {} + if self.enable_mtp_training: + mtp_loss_stats = self._collect_mtp_loss() + + # Step 5: Optimizer step + train_stats = self.optimizer_step() + + # Merge MTP stats into train stats + if mtp_loss_stats: + train_stats.update(mtp_loss_stats) + + return train_stats @torch.no_grad() def eval_batch( @@ -797,7 +1033,15 @@ def process_output( self.forward_backward_batch(mb_list, process_output, forward_only=True) - # Step 4: Aggregate losses + # Step 4: Collect MTP loss during eval if enabled + if self.enable_mtp_training: + mtp_loss_stats = self._collect_mtp_loss() + if mtp_loss_stats: + self.logger.info( + f"[MTPTrain] Eval MTP loss: {mtp_loss_stats.get('mtp_loss', 'N/A')}" + ) + + # Step 5: Aggregate losses if mpu.is_pipeline_last_stage(): return aggregate_eval_losses(losses, mpu.get_data_parallel_group()) return None @@ -1404,9 +1648,14 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: buffer_size = 0 converted_named_tensors = [] + mtp_param_count = 0 + mtp_param_bytes = 0 for name, param in get_named_parameters(self.model, num_moe_experts): if ".experts." in name: continue + if ".mtp." in name: + mtp_param_count += 1 + mtp_param_bytes += param.numel() * param.element_size() if self.config.use_lora and ( ".adapter." not in name or not getattr(param, "requires_grad", False) ): @@ -1429,6 +1678,18 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: meta.version, ) + if mtp_param_count > 0: + self.logger.info( + f"[MTPTrain] Weight sync: {mtp_param_count} MTP parameters " + f"({mtp_param_bytes / 1024 / 1024:.2f} MB) synced to inference engine " + f"at version={meta.version}" + ) + elif self.enable_mtp_training: + self.logger.warning( + f"[MTPTrain] enable_mtp_training=True but 0 MTP parameters found " + f"during weight sync at version={meta.version}. " + f"MTP draft model weights will NOT be updated!" + ) dist.barrier(group=self.cpu_group) buffer_size = 0 diff --git a/areal/engine/megatron_utils/megatron.py b/areal/engine/megatron_utils/megatron.py index 85b445b0ac..55d846fb34 100644 --- a/areal/engine/megatron_utils/megatron.py +++ b/areal/engine/megatron_utils/megatron.py @@ -159,6 +159,46 @@ def remove_padding( return param + +def _convert_mtp_layer_to_hf( + name: str, + param: Parameter | Tensor | FP8BlockwiseTensorHelper, + tf_config: TransformerConfig, +) -> list[tuple[str, Tensor]] | None: + """Convert MCore MTP layer parameter names to HuggingFace format. + + MCore MTP layers follow the naming pattern: + module.module.decoder.mtp_layers.{layer_idx}.{submodule}.{param} + which maps to HF format: + model.mtp_layers.{layer_idx}.{submodule}.{param} + + Returns a list of (hf_name, param) tuples if the parameter is an MTP + parameter, or None if it is not. + """ + import re + mtp_match = re.match( + r"module\.module\.decoder\.mtp_layers\.(\d+)\.(.+)", name + ) + if mtp_match is None: + return None + + layer_idx = int(mtp_match.group(1)) + remainder = mtp_match.group(2) + + # Map common MCore submodule names to HF names + hf_remainder = remainder + + # enorm / hnorm -> input_layernorm / post_attention_layernorm equivalent + hf_remainder = hf_remainder.replace("enorm.weight", "enorm.weight") + hf_remainder = hf_remainder.replace("hnorm.weight", "hnorm.weight") + + # Note: Some models (e.g., MiMo) may need column-half swap for eh_proj. + # This should be handled in model-specific conversion functions, not here. + # The generic MTP converter passes eh_proj through unchanged. + + hf_name = f"model.mtp_layers.{layer_idx}.{hf_remainder}" + return [(hf_name, param)] + # Adapted from slime def convert_qwen3moe_to_hf( tf_config: TransformerConfig, @@ -172,6 +212,11 @@ def convert_qwen3moe_to_hf( if name == "module.module.decoder.final_layernorm.weight": return [("model.norm.weight", param)] + # Check for MTP layer parameters + mtp_result = _convert_mtp_layer_to_hf(name, param, tf_config) + if mtp_result is not None: + return mtp_result + try: head_dim = ( tf_config.kv_channels @@ -329,6 +374,11 @@ def convert_qwen2_to_hf( if name == "module.module.decoder.final_layernorm.weight": return [("model.norm.weight", param)] + # Check for MTP layer parameters + mtp_result = _convert_mtp_layer_to_hf(name, param, tf_config) + if mtp_result is not None: + return mtp_result + try: head_dim = ( tf_config.kv_channels @@ -419,6 +469,11 @@ def convert_deepseekv3_to_hf( if name == "module.module.decoder.final_layernorm.weight": return [("model.norm.weight", param)] + # Check for MTP layer parameters + mtp_result = _convert_mtp_layer_to_hf(name, param, tf_config) + if mtp_result is not None: + return mtp_result + try: head_dim = ( tf_config.kv_channels @@ -593,6 +648,11 @@ def convert_bailingmoe_to_hf( if name == "module.module.decoder.final_layernorm.weight": return [("model.norm.weight", param)] + # Check for MTP layer parameters + mtp_result = _convert_mtp_layer_to_hf(name, param, tf_config) + if mtp_result is not None: + return mtp_result + decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)" match = re.match(decoder_layers_pattern, name) if match: diff --git a/areal/engine/megatron_utils/packed_context_parallel.py b/areal/engine/megatron_utils/packed_context_parallel.py index 70e5d1c173..0653f43cb5 100644 --- a/areal/engine/megatron_utils/packed_context_parallel.py +++ b/areal/engine/megatron_utils/packed_context_parallel.py @@ -123,6 +123,7 @@ def postprocess_packed_seqs_context_parallel( def packed_context_parallel_forward( model: torch.nn.Module, input_: dict[str, Any], + extra_block_kwargs: dict[str, Any] | None = None, ): input_ids = input_["input_ids"] position_ids = input_["position_ids"] @@ -144,6 +145,20 @@ def packed_context_parallel_forward( ) input_ids = input_ids.contiguous() + # Also split MTP labels with the same CP logic if present + if extra_block_kwargs and "mtp_kwargs" in extra_block_kwargs: + mtp_kwargs = extra_block_kwargs["mtp_kwargs"] + if "mtp_labels" in mtp_kwargs: + mtp_labels_split, _ = preprocess_packed_seqs_context_parallel( + mtp_kwargs["mtp_labels"], cu_seqlens + ) + mtp_kwargs["mtp_labels"] = mtp_labels_split.contiguous() + if "mtp_loss_mask" in mtp_kwargs: + mtp_mask_split, _ = preprocess_packed_seqs_context_parallel( + mtp_kwargs["mtp_loss_mask"], cu_seqlens + ) + mtp_kwargs["mtp_loss_mask"] = mtp_mask_split.contiguous() + # Pass tree_triton_data as attention_mask if present (for Triton tree attention) # Otherwise use the attention_mask from input (could be dense tensor for flex attention) final_attention_mask = ( @@ -156,6 +171,7 @@ def packed_context_parallel_forward( attention_mask=final_attention_mask, position_ids=position_ids, packed_seq_params=packed_seq_params, + extra_block_kwargs=extra_block_kwargs, ) except Exception as e: raise RuntimeError( diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 9804d08c3f..49b01fdbc9 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -31,8 +31,11 @@ from areal.infra.platforms import current_platform from areal.infra.utils.launcher import TRITON_CACHE_PATH from areal.utils import perf_tracer, stats_tracker +from areal.utils.logging import getLogger from areal.utils.network import format_host_for_url +logger = getLogger("SGLangRemote") + class SGLangBackend: """SGLang-specific backend implementation for remote inference.""" @@ -105,12 +108,24 @@ def parse_generation_response( pybase64.b64decode(routed_experts.encode("utf-8")), dtype=np.int32 ).reshape(num_sgl_token, -1) + # Extract speculative decoding statistics if available + spec_accept_token_num = meta_info.get("spec_accept_token_num", None) + spec_draft_token_num = meta_info.get("spec_draft_token_num", None) + if spec_accept_token_num is not None and spec_draft_token_num is not None: + if spec_draft_token_num > 0: + accept_rate = spec_accept_token_num / spec_draft_token_num + logger.debug( + f"[SpecDec] SGLang response: accept={spec_accept_token_num}, " + f"draft={spec_draft_token_num}, rate={accept_rate:.4f}" + ) if stop_reason == "abort" and stop_message.startswith("Abort before prefill"): return HttpGenerationResult( output_tokens=[], output_logprobs=[], stop_reason=stop_reason, routed_experts=routed_experts, + spec_accept_token_num=spec_accept_token_num, + spec_draft_token_num=spec_draft_token_num, ) output_tokens = [x[1] for x in meta_info["output_token_logprobs"]] @@ -121,6 +136,8 @@ def parse_generation_response( output_logprobs=output_logprobs, stop_reason=stop_reason, routed_experts=routed_experts, + spec_accept_token_num=spec_accept_token_num, + spec_draft_token_num=spec_draft_token_num, ) def build_disk_weight_update_requests( diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index 835537cfe7..80f631ef81 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -763,6 +763,8 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: accumulated_output_logprobs = [] accumulated_versions = [] accumulated_routed_experts: list[np.ndarray] = [] + accumulated_spec_accept_tokens = 0 + accumulated_spec_draft_tokens = 0 # A single "rid" shares the same server to allow KV cache reuse if req.rid in self.rid_to_address: @@ -840,6 +842,12 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: if gen_result.routed_experts is not None: accumulated_routed_experts.append(gen_result.routed_experts) + # Accumulate speculative decoding statistics + if gen_result.spec_accept_token_num is not None: + accumulated_spec_accept_tokens += gen_result.spec_accept_token_num + if gen_result.spec_draft_token_num is not None: + accumulated_spec_draft_tokens += gen_result.spec_draft_token_num + # Update request for next iteration req.input_ids += gen_result.output_tokens req.gconfig.max_new_tokens -= len(gen_result.output_tokens) @@ -878,6 +886,8 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: tokenizer=req.tokenizer, processor=req.processor, routed_experts=accumulated_routed_experts, + spec_accept_token_num=accumulated_spec_accept_tokens, + spec_draft_token_num=accumulated_spec_draft_tokens, ) return response diff --git a/areal/models/mcore/registry.py b/areal/models/mcore/registry.py index 9caaf37e14..477b476545 100644 --- a/areal/models/mcore/registry.py +++ b/areal/models/mcore/registry.py @@ -275,6 +275,22 @@ def make_mcore_model( "Virtual pipeline parallelism requires mbridge-backed models." ) transformer_layer_spec = make_mcore_layer_specs(hf_config, tf_config) + + # Build MTP block spec if MTP is configured + mtp_block_spec = None + mtp_num_layers = getattr(tf_config, "mtp_num_layers", 0) + if mtp_num_layers > 0: + try: + from megatron.core.models.gpt.gpt_layer_specs import get_mtp_block_spec + mtp_block_spec = get_mtp_block_spec(tf_config, transformer_layer_spec) + logger.info( + f"[MTPTrain] Created MTP block spec with {mtp_num_layers} layers" + ) + except ImportError: + logger.warning( + "[MTPTrain] Cannot import get_mtp_block_spec from megatron.core. " + "MTP layers will not be created. Ensure megatron-core >= 0.11.0." + ) rope_scaling_args = {} if hf_config.rope_scaling is not None: if hf_config.rope_scaling["type"] != "linear": @@ -297,6 +313,7 @@ def make_mcore_model( rotary_base=hf_config.rope_theta, **rope_scaling_args, # vp_stage=None TODO: virtual pipeline parallel + **({"mtp_block_spec": mtp_block_spec} if mtp_block_spec is not None else {}), ) # Replace output_layer with ValueHead for critic models diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 1c24a549c6..cafca61c3b 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -353,6 +353,15 @@ def _ppo_update(self, data: dict[str, Any]) -> None: ) stats_tracker.scalar(**train_stat) + if "mtp_loss" in train_stat: + mtp_loss_val = train_stat["mtp_loss"] + logger.info( + f"[MTPTrain] MTP loss={mtp_loss_val:.6f}, " + f"scaling_factor={self.config.mtp_loss_scaling_factor}, " + f"scaled_loss={mtp_loss_val * self.config.mtp_loss_scaling_factor:.6f}" + ) + stats_tracker.scalar(mtp_loss=mtp_loss_val) + class PPOActorController(TrainController): def compute_logp(self, *args, **kwargs): diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 98c7c50ce6..d204998ba4 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -649,6 +649,14 @@ def _create_train_engine( else: actor = actor_cls(config=actor_config) actor.create_process_group(parallel_strategy=alloc.parallel) + + # Log MTP training configuration if enabled + if getattr(actor_config, "enable_mtp_training", False): + logger.info( + f"[MTPTrain] Actor engine created with MTP training enabled: " + f"mtp_num_layers={actor_config.mtp_num_layers}, " + f"mtp_loss_scaling_factor={actor_config.mtp_loss_scaling_factor}" + ) return actor def _create_critic( diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index 54dfcc63b4..db5e59ed7a 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -131,6 +131,38 @@ async def _collect_samples( stats_tracker.get(workflow_context.stat_scope()).scalar(reward=reward) + # Log speculative decoding statistics if available + if resp.spec_draft_token_num > 0: + accept_rate = ( + resp.spec_accept_token_num / resp.spec_draft_token_num + if resp.spec_draft_token_num > 0 + else 0.0 + ) + stats_tracker.get(workflow_context.stat_scope()).scalar( + spec_accept_rate=accept_rate, + spec_accept_tokens=float(resp.spec_accept_token_num), + spec_draft_tokens=float(resp.spec_draft_token_num), + ) + if accept_rate < 0.1: + logger.warning( + f"[SpecDec] Very low accept rate: {accept_rate:.4f} " + f"(accept={resp.spec_accept_token_num}, draft={resp.spec_draft_token_num}). " + f"Draft model may be severely out of sync with target model. " + f"Consider: 1) Reducing mtp_loss_scaling_factor; " + f"2) Checking MTP layer training status; " + f"3) Reducing speculative_num_steps." + ) + elif accept_rate < 0.5: + logger.info( + f"[SpecDec] Accept rate: {accept_rate:.4f} " + f"(accept={resp.spec_accept_token_num}, draft={resp.spec_draft_token_num})" + ) + else: + logger.info( + f"[SpecDec] Good accept rate: {accept_rate:.4f} " + f"(accept={resp.spec_accept_token_num}, draft={resp.spec_draft_token_num})" + ) + return resp, reward async def arun_episode( diff --git a/docs/en/_toc.yml b/docs/en/_toc.yml index 34638d9e7e..a8020b9d81 100644 --- a/docs/en/_toc.yml +++ b/docs/en/_toc.yml @@ -48,3 +48,4 @@ parts: - file: reference/rollout_workflow - file: reference/agent_workflow - file: reference/ai_assisted_dev + - file: tutorial/speculative_decoding diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index d51c58866c..0d467e891c 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -392,6 +392,10 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | | `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | | `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| `enable_mtp_training` | boolean | `False` | Enable MTP (Multi-Token Prediction) online training during RL. When enabled, MTP layers are trained alongside the main policy model to keep the draft model aligned with the evolving policy. | +| `mtp_num_layers` | integer | `1` | Number of MTP layers to train. Must match the model's MTP architecture. | +| `mtp_loss_scaling_factor` | float | `0.1` | Scaling factor for MTP auxiliary loss relative to the main RL loss. | +| `mtp_detach_heads` | boolean | `True` | Whether to detach hidden states before passing to MTP heads. When True (recommended for RL), MTP loss gradients only update MTP parameters, preventing the MTP auxiliary loss from corrupting the main policy gradients. When False, MTP loss gradients also flow back to the main model. | (section-ppo-critic)= @@ -533,62 +537,70 @@ Configuration for SGLang runtime. Refer to: https://github.com/sgl-project/sglang for detailed documentation. -| Parameter | Type | Default | Description | -| --------------------------------- | ----------------------- | ------------ | ----------- | -| `model_path` | string | `""` | - | -| `random_seed` | integer | `1` | - | -| `skip_tokenizer_init` | boolean | `False` | - | -| `disable_cuda_graph` | boolean | `False` | - | -| `disable_radix_cache` | boolean | `True` | - | -| `disable_cuda_graph_padding` | boolean | `False` | - | -| `enable_nccl_nvls` | boolean | `False` | - | -| `disable_outlines_disk_cache` | boolean | `False` | - | -| `disable_custom_all_reduce` | boolean | `False` | - | -| `disable_overlap_schedule` | boolean | `False` | - | -| `enable_mixed_chunk` | boolean | `False` | - | -| `enable_dp_attention` | boolean | `False` | - | -| `enable_ep_moe` | boolean | `False` | - | -| `enable_torch_compile` | boolean | `False` | - | -| `torch_compile_max_bs` | integer | `32` | - | -| `cuda_graph_max_bs` | integer \| None | `None` | - | -| `cuda_graph_bs` | list of integer \| None | `None` | - | -| `torchao_config` | string | `""` | - | -| `enable_nan_detection` | boolean | `False` | - | -| `enable_p2p_check` | boolean | `False` | - | -| `triton_attention_reduce_in_fp32` | boolean | `False` | - | -| `triton_attention_num_kv_splits` | integer | `8` | - | -| `num_continuous_decode_steps` | integer | `1` | - | -| `enable_memory_saver` | boolean | `False` | - | -| `allow_auto_truncate` | boolean | `False` | - | -| `attention_backend` | string \| None | `"fa3"` | - | -| `enable_multimodal` | boolean | `False` | - | -| `sampling_backend` | string \| None | `None` | - | -| `context_length` | integer \| None | `32768` | - | -| `mem_fraction_static` | float \| None | `0.9` | - | -| `max_running_requests` | integer \| None | `None` | - | -| `chunked_prefill_size` | integer \| None | `-1` | - | -| `max_prefill_tokens` | integer | `32768` | - | -| `schedule_policy` | string | `"lpm"` | - | -| `schedule_conservativeness` | float | `1.0` | - | -| `cpu_offload_gb` | integer | `0` | - | -| `dtype` | string | `"bfloat16"` | - | -| `kv_cache_dtype` | string | `"auto"` | - | -| `dp_size` | integer | `1` | - | -| `ep_size` | integer | `1` | - | -| `enable_lora` | boolean \| None | `None` | - | -| `max_lora_rank` | integer \| None | `None` | - | -| `max_loaded_loras` | integer | `8` | - | -| `lora_paths` | list of string \| None | `None` | - | -| `lora_backend` | string | `"triton"` | - | -| `log_level` | string | `"warning"` | - | -| `log_level_http` | string \| None | `"warning"` | - | -| `log_requests` | boolean | `False` | - | -| `log_requests_level` | integer | `0` | - | -| `show_time_cost` | boolean | `False` | - | -| `enable_metrics` | boolean | `True` | - | -| `decode_log_interval` | integer | `1` | - | -| `enable_multithread_load` | boolean | `False` | - | -| `enable_return_routed_experts` | boolean | `False` | - | +| Parameter | Type | Default | Description | +| --------------------------------- | ----------------------- | ------------ | -------------------------------------------------------------------------------------------------------------------------------------------------- | +| `model_path` | string | `""` | - | +| `random_seed` | integer | `1` | - | +| `skip_tokenizer_init` | boolean | `False` | - | +| `disable_cuda_graph` | boolean | `False` | - | +| `disable_radix_cache` | boolean | `True` | - | +| `disable_cuda_graph_padding` | boolean | `False` | - | +| `enable_nccl_nvls` | boolean | `False` | - | +| `disable_outlines_disk_cache` | boolean | `False` | - | +| `disable_custom_all_reduce` | boolean | `False` | - | +| `disable_overlap_schedule` | boolean | `False` | - | +| `enable_mixed_chunk` | boolean | `False` | - | +| `enable_dp_attention` | boolean | `False` | - | +| `enable_ep_moe` | boolean | `False` | - | +| `enable_torch_compile` | boolean | `False` | - | +| `torch_compile_max_bs` | integer | `32` | - | +| `cuda_graph_max_bs` | integer \| None | `None` | - | +| `cuda_graph_bs` | list of integer \| None | `None` | - | +| `torchao_config` | string | `""` | - | +| `enable_nan_detection` | boolean | `False` | - | +| `enable_p2p_check` | boolean | `False` | - | +| `triton_attention_reduce_in_fp32` | boolean | `False` | - | +| `triton_attention_num_kv_splits` | integer | `8` | - | +| `num_continuous_decode_steps` | integer | `1` | - | +| `enable_memory_saver` | boolean | `False` | - | +| `allow_auto_truncate` | boolean | `False` | - | +| `attention_backend` | string \| None | `"fa3"` | - | +| `enable_multimodal` | boolean | `False` | - | +| `sampling_backend` | string \| None | `None` | - | +| `context_length` | integer \| None | `32768` | - | +| `mem_fraction_static` | float \| None | `0.9` | - | +| `max_running_requests` | integer \| None | `None` | - | +| `chunked_prefill_size` | integer \| None | `-1` | - | +| `max_prefill_tokens` | integer | `32768` | - | +| `schedule_policy` | string | `"lpm"` | - | +| `schedule_conservativeness` | float | `1.0` | - | +| `cpu_offload_gb` | integer | `0` | - | +| `dtype` | string | `"bfloat16"` | - | +| `kv_cache_dtype` | string | `"auto"` | - | +| `dp_size` | integer | `1` | - | +| `ep_size` | integer | `1` | - | +| `enable_lora` | boolean \| None | `None` | - | +| `max_lora_rank` | integer \| None | `None` | - | +| `max_loaded_loras` | integer | `8` | - | +| `lora_paths` | list of string \| None | `None` | - | +| `lora_backend` | string | `"triton"` | - | +| `log_level` | string | `"warning"` | - | +| `log_level_http` | string \| None | `"warning"` | - | +| `log_requests` | boolean | `False` | - | +| `log_requests_level` | integer | `0` | - | +| `show_time_cost` | boolean | `False` | - | +| `enable_metrics` | boolean | `True` | - | +| `decode_log_interval` | integer | `1` | - | +| `enable_multithread_load` | boolean | `False` | - | +| `enable_return_routed_experts` | boolean | `False` | - | +| `speculative_algorithm` | string \| None | `None` | Speculative decoding algorithm. Options: 'EAGLE', 'EAGLE3'. None disables speculative decoding. | +| `speculative_draft_model_path` | string \| None | `None` | Path to the draft model for speculative decoding. | +| `speculative_num_steps` | integer | `3` | Number of speculative decoding draft steps. | +| `speculative_eagle_topk` | integer | `1` | Top-k value for EAGLE draft token selection. | +| `speculative_num_draft_tokens` | integer | `4` | Number of draft tokens per speculative step. | +| `speculative_attention_mode` | string \| None | `None` | Attention mode for speculative decoding. E.g., 'full', 'sparse'. | +| `enable_multi_layer_eagle` | boolean | `False` | - | +| `enable_draft_weights_cpu_backup` | boolean \| None | `None` | Keep draft model weights on CPU as backup during GPU offload cycles. Essential for colocated training+inference mode to prevent draft weight loss. | (section-v-llm)= @@ -941,6 +953,9 @@ Refer to Megatron-LM documentation for implementation details. | `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. | | `fp8_config` | [`FP8EngineConfig`](section-fp8-engine) \| None | `None` | - | | `bridge_type` | string | `"mbridge"` | Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'. **Choices:** `mbridge`, `megatron-bridge` | +| `mtp_num_layers` | integer | `0` | Number of MTP (Multi-Token Prediction) layers for speculative decoding training. 0 means MTP is disabled. | +| `mtp_loss_scaling_factor` | float | `0.1` | Scaling factor for MTP auxiliary loss. Controls the weight of MTP loss relative to the main RL loss. | +| `mtp_detach_heads` | boolean | `True` | Whether to detach hidden states before passing to MTP heads in MegatronEngine. When True, MTP loss gradients only update MTP parameters. | (section-open-ai-proxy)= @@ -1106,5 +1121,9 @@ Configuration class: TeacherConfig | `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | | `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | | `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| `enable_mtp_training` | boolean | `False` | Enable MTP (Multi-Token Prediction) online training during RL. When enabled, MTP layers are trained alongside the main policy model to keep the draft model aligned with the evolving policy. | +| `mtp_num_layers` | integer | `1` | Number of MTP layers to train. Must match the model's MTP architecture. | +| `mtp_loss_scaling_factor` | float | `0.1` | Scaling factor for MTP auxiliary loss relative to the main RL loss. | +| `mtp_detach_heads` | boolean | `True` | Whether to detach hidden states before passing to MTP heads. When True (recommended for RL), MTP loss gradients only update MTP parameters, preventing the MTP auxiliary loss from corrupting the main policy gradients. When False, MTP loss gradients also flow back to the main model. | | `rl_loss_weight` | float | `1.0` | RL loss weight | | `distill_loss_weight` | float | `0.005` | Distillation loss weight | diff --git a/docs/en/tutorial/speculative_decoding.md b/docs/en/tutorial/speculative_decoding.md new file mode 100644 index 0000000000..4008e2b475 --- /dev/null +++ b/docs/en/tutorial/speculative_decoding.md @@ -0,0 +1,273 @@ +# Speculative Decoding with EAGLE + +## Overview + +Speculative decoding is a technique that accelerates autoregressive text generation by +using a lightweight **draft model** to propose multiple candidate tokens in parallel, +which the full **target model** then verifies in a single forward pass. When candidates +are accepted, the effective throughput increases significantly — often 2-3x — without +changing the output distribution. + +AReaL integrates **EAGLE** (Extrapolation Algorithm for Greater Language-model +Efficiency) as its speculative decoding backend. EAGLE uses the target model's hidden +states to predict future tokens through a small auxiliary head, making it particularly +well-suited for RL training loops where the policy model evolves continuously. + +### Why Speculative Decoding for RL Training? + +In RLHF / GRPO training pipelines, rollout generation is often the throughput +bottleneck. Speculative decoding directly addresses this by: + +- Reducing per-sample generation latency during rollout +- Increasing GPU utilization during the inference phase +- Maintaining identical output quality (the verification step is exact) + +When combined with **MTP (Multi-Token Prediction) online training**, the draft model +stays aligned with the evolving policy, preserving high accept rates throughout training. + +## Prerequisites + +Before enabling speculative decoding, ensure: + +1. **Model with MTP layers**: Your base model must include MTP (Multi-Token Prediction) + head layers. Models such as `Qwen/Qwen3-0.6B` and other Qwen3 variants ship with + MTP layers that can serve as EAGLE draft heads. + +2. **SGLang backend**: Speculative decoding requires the SGLang inference backend. + Ensure SGLang is installed and configured: + + ```bash + pip install "sglang[all]>=0.4.7" + ``` + +3. **Megatron-Core >= 0.12.0**: MTP online training requires Megatron-Core version + 0.12.0 or later, which includes the `MultiTokenPrediction` module with built-in + gradient isolation (embedding detach and functional_call for LM head). This ensures + MTP loss gradients only update MTP layer parameters without corrupting the main + policy model. + +4. **Sufficient GPU memory**: The draft model adds a small memory overhead on the + inference GPUs. Reduce `sglang.mem_fraction_static` if needed (e.g., from `0.85` to + `0.80`). + +## Configuration + +### SGLang EAGLE Configuration + +Speculative decoding is configured under the `sglang` section of your experiment YAML. +The key fields live in `SGLangConfig`: + +```yaml +sglang: + model_path: ${actor.path} + dtype: bfloat16 + mem_fraction_static: 0.80 + context_length: 32768 + + # --- Speculative Decoding --- + speculative_algorithm: "EAGLE" # or "EAGLE3" + speculative_draft_model_path: null # null = use built-in MTP heads + speculative_num_steps: 3 # number of draft steps per iteration + speculative_eagle_topk: 1 # top-k for draft token selection + speculative_num_draft_tokens: 4 # draft tokens proposed per step + speculative_attention_mode: null # null uses default attention +``` + +| Parameter | Default | Description | +|---|---|---| +| `speculative_algorithm` | `null` | Algorithm name: `"EAGLE"` or `"EAGLE3"`. `null` disables speculative decoding. | +| `speculative_draft_model_path` | `null` | Path to an external draft model. `null` reuses the target model's built-in MTP layers. | +| `speculative_num_steps` | `3` | How many autoregressive draft steps EAGLE performs before verification. | +| `speculative_eagle_topk` | `1` | Number of top-k candidates retained at each draft step. | +| `speculative_num_draft_tokens` | `4` | Total draft tokens fed to the verifier per speculative iteration. | +| `speculative_attention_mode` | `null` | Override attention kernel used during draft. `null` uses the engine default. | + +### MTP Online Training Configuration + +To keep the draft model aligned with the policy as it trains, enable MTP online +training in the `actor` section: + +```yaml +actor: + backend: "megatron:d4p1t1" + path: Qwen/Qwen3-0.6B + + # --- MTP Online Training --- + enable_mtp_training: true + mtp_num_layers: 1 # must match model's MTP architecture + mtp_loss_scaling_factor: 0.1 # weight of MTP loss vs. main RL loss + + # Megatron-specific MTP settings (in actor.megatron) + megatron: + mtp_num_layers: 1 # mirrors actor.mtp_num_layers + mtp_loss_scaling_factor: 0.1 # mirrors actor.mtp_loss_scaling_factor +``` + +| Parameter | Default | Description | +|---|---|---| +| `enable_mtp_training` | `false` | Master switch for MTP online training. | +| `mtp_num_layers` | `1` | Number of MTP head layers to train. Must be > 0 when enabled. | +| `mtp_loss_scaling_factor` | `0.1` | Weight of the MTP auxiliary loss. Must be in (0, 1.0]. | + +When `enable_mtp_training` is `true`, the trainer computes an auxiliary next-token +prediction loss on the MTP heads and adds it (scaled) to the main RL objective. This +ensures the draft heads continuously improve their prediction accuracy as the policy +changes. + +## Full Example + +Below is a minimal GRPO + EAGLE configuration for GSM8K with 4 GPUs: + +```yaml +experiment_name: gsm8k-grpo-eagle +trial_name: trial0 +seed: 42 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 4 + +actor: + backend: "megatron:d2p1t1" + path: Qwen/Qwen3-0.6B + enable_mtp_training: true + mtp_num_layers: 1 + mtp_loss_scaling_factor: 0.1 + +sglang: + model_path: ${actor.path} + speculative_algorithm: "EAGLE" + speculative_num_steps: 3 + speculative_num_draft_tokens: 4 + mem_fraction_static: 0.80 + +train_dataset: + path: openai/gsm8k + type: rl + batch_size: 128 +``` + +For the complete configuration file, see +[`examples/math/gsm8k_grpo_megatron_eagle.yaml`](https://github.com/inclusionAI/AReaL/blob/main/examples/math/gsm8k_grpo_megatron_eagle.yaml). + +## Monitoring + +### Key Metrics + +During training, watch the following metrics in your logs or WandB dashboard: + +1. **Speculative Accept Rate** + - Logged as `spec_accept_rate` (= `spec_accept_token_num / spec_draft_token_num`) + - A healthy accept rate is **0.6 - 0.9** for well-aligned draft models + - If accept rate drops below **0.4**, the draft model is falling behind the policy + +2. **MTP Loss** + - Logged as `mtp_loss` in training statistics + - Should decrease over time; a rising MTP loss indicates training instability + - Typical range: **0.5 - 2.0** depending on model size and task + +3. **Generation Throughput** + - Compare tokens/second with and without speculative decoding + - Expected speedup: **1.5x - 3x** depending on accept rate and model architecture + +### Interpreting Accept Rate Trends + +| Trend | Meaning | Action | +|---|---|---| +| Stable 0.7+ | Draft model is well-aligned | No action needed | +| Gradual decline | Policy is evolving faster than draft | Increase `mtp_loss_scaling_factor` | +| Sudden drop | Possible learning rate spike or data shift | Check training stability | +| Very low (<0.3) | Draft model is ineffective | Verify MTP layers are being trained | + +## Troubleshooting + +### Accept Rate is Very Low + +1. **Verify MTP training is enabled**: Check that `actor.enable_mtp_training: true` is + set. Without online training, the draft model will quickly become stale. + +2. **Check MTP layer count**: Ensure `actor.mtp_num_layers` matches your model's + architecture. Qwen3 models typically have 1 MTP layer. + +3. **Increase MTP loss weight**: If the accept rate degrades over time, try increasing + `mtp_loss_scaling_factor` from `0.1` to `0.2` or `0.3`. + +### Out of Memory (OOM) During Inference + +1. **Reduce memory fraction**: Lower `sglang.mem_fraction_static` (e.g., `0.75`). + +2. **Reduce draft tokens**: Lower `speculative_num_draft_tokens` from `4` to `2`. + +3. **Reduce draft steps**: Lower `speculative_num_steps` from `3` to `2`. + +### Training is Slower Than Expected + +1. **Check GPU allocation**: Ensure inference and training GPUs are properly separated. + Use `sglang:d2p1t1` with `megatron:d2p1t1` on 4 GPUs for balanced allocation. + +2. **Profile the pipeline**: Enable `perf_tracer.enabled: true` to identify whether + the bottleneck is in generation, training, or data loading. + +3. **Disable speculative decoding temporarily**: Set `speculative_algorithm: null` and + compare throughput to isolate whether the overhead is from speculation itself. + +### MTP Loss is Not Decreasing + +1. **Verify model supports MTP**: Not all model architectures include MTP heads. Check + that the model's config includes MTP layer definitions. + +2. **Check learning rate**: The MTP heads share the actor's optimizer. If the base + learning rate is too low, MTP training may stagnate. + +3. **Inspect gradient flow**: Ensure `actor.gradient_checkpointing` is not interfering + with MTP gradient computation. + +## Advanced Configuration + +### Using an External Draft Model + +Instead of relying on built-in MTP layers, you can provide a separate draft model: + +```yaml +sglang: + speculative_algorithm: "EAGLE" + speculative_draft_model_path: /path/to/eagle-draft-model +``` + +Note that when using an external draft model, `enable_mtp_training` should typically be +`false` unless the external model's weights are also updated during training. + +### EAGLE3 Algorithm + +EAGLE3 is an improved variant that supports more flexible tree-structured speculation: + +```yaml +sglang: + speculative_algorithm: "EAGLE3" + speculative_num_steps: 5 + speculative_eagle_topk: 2 + speculative_num_draft_tokens: 8 +``` + +EAGLE3 generally achieves higher accept rates but uses more memory for the expanded +draft tree. + +### Draft Weight CPU Backup + +When using colocated training and inference (i.e., the same GPUs serve both), draft +model weights may be lost during GPU memory reclamation. Enable CPU backup: + +```yaml +sglang: + enable_draft_weights_cpu_backup: true +``` + +This keeps a CPU copy of draft weights that is restored after each training step. + +## References + +- [EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty](https://arxiv.org/abs/2401.15077) +- [SGLang Documentation](https://sgl-project.github.io/) +- [AReaL Megatron Backend Tutorial](megatron.md) +- [AReaL Allocation Mode Reference](../reference/alloc_mode.md) diff --git a/docs/zh/_toc.yml b/docs/zh/_toc.yml index 1a33f9849e..4e70f95c63 100644 --- a/docs/zh/_toc.yml +++ b/docs/zh/_toc.yml @@ -48,3 +48,4 @@ parts: - file: reference/rollout_workflow - file: reference/agent_workflow - file: reference/ai_assisted_dev + - file: tutorial/speculative_decoding diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index afd64db4af..86cb648454 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -390,6 +390,10 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | | `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | | `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| `enable_mtp_training` | boolean | `False` | Enable MTP (Multi-Token Prediction) online training during RL. When enabled, MTP layers are trained alongside the main policy model to keep the draft model aligned with the evolving policy. | +| `mtp_num_layers` | integer | `1` | Number of MTP layers to train. Must match the model's MTP architecture. | +| `mtp_loss_scaling_factor` | float | `0.1` | Scaling factor for MTP auxiliary loss relative to the main RL loss. | +| `mtp_detach_heads` | boolean | `True` | Whether to detach hidden states before passing to MTP heads. When True (recommended for RL), MTP loss gradients only update MTP parameters, preventing the MTP auxiliary loss from corrupting the main policy gradients. When False, MTP loss gradients also flow back to the main model. | (section-ppo-critic)= @@ -531,62 +535,70 @@ Configuration for SGLang runtime. Refer to: https://github.com/sgl-project/sglang for detailed documentation. -| Parameter | Type | Default | Description | -| --------------------------------- | ----------------------- | ------------ | ----------- | -| `model_path` | string | `""` | - | -| `random_seed` | integer | `1` | - | -| `skip_tokenizer_init` | boolean | `False` | - | -| `disable_cuda_graph` | boolean | `False` | - | -| `disable_radix_cache` | boolean | `True` | - | -| `disable_cuda_graph_padding` | boolean | `False` | - | -| `enable_nccl_nvls` | boolean | `False` | - | -| `disable_outlines_disk_cache` | boolean | `False` | - | -| `disable_custom_all_reduce` | boolean | `False` | - | -| `disable_overlap_schedule` | boolean | `False` | - | -| `enable_mixed_chunk` | boolean | `False` | - | -| `enable_dp_attention` | boolean | `False` | - | -| `enable_ep_moe` | boolean | `False` | - | -| `enable_torch_compile` | boolean | `False` | - | -| `torch_compile_max_bs` | integer | `32` | - | -| `cuda_graph_max_bs` | integer \| None | `None` | - | -| `cuda_graph_bs` | list of integer \| None | `None` | - | -| `torchao_config` | string | `""` | - | -| `enable_nan_detection` | boolean | `False` | - | -| `enable_p2p_check` | boolean | `False` | - | -| `triton_attention_reduce_in_fp32` | boolean | `False` | - | -| `triton_attention_num_kv_splits` | integer | `8` | - | -| `num_continuous_decode_steps` | integer | `1` | - | -| `enable_memory_saver` | boolean | `False` | - | -| `allow_auto_truncate` | boolean | `False` | - | -| `attention_backend` | string \| None | `"fa3"` | - | -| `enable_multimodal` | boolean | `False` | - | -| `sampling_backend` | string \| None | `None` | - | -| `context_length` | integer \| None | `32768` | - | -| `mem_fraction_static` | float \| None | `0.9` | - | -| `max_running_requests` | integer \| None | `None` | - | -| `chunked_prefill_size` | integer \| None | `-1` | - | -| `max_prefill_tokens` | integer | `32768` | - | -| `schedule_policy` | string | `"lpm"` | - | -| `schedule_conservativeness` | float | `1.0` | - | -| `cpu_offload_gb` | integer | `0` | - | -| `dtype` | string | `"bfloat16"` | - | -| `kv_cache_dtype` | string | `"auto"` | - | -| `dp_size` | integer | `1` | - | -| `ep_size` | integer | `1` | - | -| `enable_lora` | boolean \| None | `None` | - | -| `max_lora_rank` | integer \| None | `None` | - | -| `max_loaded_loras` | integer | `8` | - | -| `lora_paths` | list of string \| None | `None` | - | -| `lora_backend` | string | `"triton"` | - | -| `log_level` | string | `"warning"` | - | -| `log_level_http` | string \| None | `"warning"` | - | -| `log_requests` | boolean | `False` | - | -| `log_requests_level` | integer | `0` | - | -| `show_time_cost` | boolean | `False` | - | -| `enable_metrics` | boolean | `True` | - | -| `decode_log_interval` | integer | `1` | - | -| `enable_multithread_load` | boolean | `False` | - | -| `enable_return_routed_experts` | boolean | `False` | - | +| Parameter | Type | Default | Description | +| --------------------------------- | ----------------------- | ------------ | -------------------------------------------------------------------------------------------------------------------------------------------------- | +| `model_path` | string | `""` | - | +| `random_seed` | integer | `1` | - | +| `skip_tokenizer_init` | boolean | `False` | - | +| `disable_cuda_graph` | boolean | `False` | - | +| `disable_radix_cache` | boolean | `True` | - | +| `disable_cuda_graph_padding` | boolean | `False` | - | +| `enable_nccl_nvls` | boolean | `False` | - | +| `disable_outlines_disk_cache` | boolean | `False` | - | +| `disable_custom_all_reduce` | boolean | `False` | - | +| `disable_overlap_schedule` | boolean | `False` | - | +| `enable_mixed_chunk` | boolean | `False` | - | +| `enable_dp_attention` | boolean | `False` | - | +| `enable_ep_moe` | boolean | `False` | - | +| `enable_torch_compile` | boolean | `False` | - | +| `torch_compile_max_bs` | integer | `32` | - | +| `cuda_graph_max_bs` | integer \| None | `None` | - | +| `cuda_graph_bs` | list of integer \| None | `None` | - | +| `torchao_config` | string | `""` | - | +| `enable_nan_detection` | boolean | `False` | - | +| `enable_p2p_check` | boolean | `False` | - | +| `triton_attention_reduce_in_fp32` | boolean | `False` | - | +| `triton_attention_num_kv_splits` | integer | `8` | - | +| `num_continuous_decode_steps` | integer | `1` | - | +| `enable_memory_saver` | boolean | `False` | - | +| `allow_auto_truncate` | boolean | `False` | - | +| `attention_backend` | string \| None | `"fa3"` | - | +| `enable_multimodal` | boolean | `False` | - | +| `sampling_backend` | string \| None | `None` | - | +| `context_length` | integer \| None | `32768` | - | +| `mem_fraction_static` | float \| None | `0.9` | - | +| `max_running_requests` | integer \| None | `None` | - | +| `chunked_prefill_size` | integer \| None | `-1` | - | +| `max_prefill_tokens` | integer | `32768` | - | +| `schedule_policy` | string | `"lpm"` | - | +| `schedule_conservativeness` | float | `1.0` | - | +| `cpu_offload_gb` | integer | `0` | - | +| `dtype` | string | `"bfloat16"` | - | +| `kv_cache_dtype` | string | `"auto"` | - | +| `dp_size` | integer | `1` | - | +| `ep_size` | integer | `1` | - | +| `enable_lora` | boolean \| None | `None` | - | +| `max_lora_rank` | integer \| None | `None` | - | +| `max_loaded_loras` | integer | `8` | - | +| `lora_paths` | list of string \| None | `None` | - | +| `lora_backend` | string | `"triton"` | - | +| `log_level` | string | `"warning"` | - | +| `log_level_http` | string \| None | `"warning"` | - | +| `log_requests` | boolean | `False` | - | +| `log_requests_level` | integer | `0` | - | +| `show_time_cost` | boolean | `False` | - | +| `enable_metrics` | boolean | `True` | - | +| `decode_log_interval` | integer | `1` | - | +| `enable_multithread_load` | boolean | `False` | - | +| `enable_return_routed_experts` | boolean | `False` | - | +| `speculative_algorithm` | string \| None | `None` | Speculative decoding algorithm. Options: 'EAGLE', 'EAGLE3'. None disables speculative decoding. | +| `speculative_draft_model_path` | string \| None | `None` | Path to the draft model for speculative decoding. | +| `speculative_num_steps` | integer | `3` | Number of speculative decoding draft steps. | +| `speculative_eagle_topk` | integer | `1` | Top-k value for EAGLE draft token selection. | +| `speculative_num_draft_tokens` | integer | `4` | Number of draft tokens per speculative step. | +| `speculative_attention_mode` | string \| None | `None` | Attention mode for speculative decoding. E.g., 'full', 'sparse'. | +| `enable_multi_layer_eagle` | boolean | `False` | - | +| `enable_draft_weights_cpu_backup` | boolean \| None | `None` | Keep draft model weights on CPU as backup during GPU offload cycles. Essential for colocated training+inference mode to prevent draft weight loss. | (section-v-llm)= @@ -939,6 +951,9 @@ Refer to Megatron-LM documentation for implementation details. | `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. | | `fp8_config` | [`FP8EngineConfig`](section-fp8-engine) \| None | `None` | - | | `bridge_type` | string | `"mbridge"` | Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'. **Choices:** `mbridge`, `megatron-bridge` | +| `mtp_num_layers` | integer | `0` | Number of MTP (Multi-Token Prediction) layers for speculative decoding training. 0 means MTP is disabled. | +| `mtp_loss_scaling_factor` | float | `0.1` | Scaling factor for MTP auxiliary loss. Controls the weight of MTP loss relative to the main RL loss. | +| `mtp_detach_heads` | boolean | `True` | Whether to detach hidden states before passing to MTP heads in MegatronEngine. When True, MTP loss gradients only update MTP parameters. | (section-open-ai-proxy)= @@ -1104,5 +1119,9 @@ Configuration class: TeacherConfig | `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | | `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | | `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| `enable_mtp_training` | boolean | `False` | Enable MTP (Multi-Token Prediction) online training during RL. When enabled, MTP layers are trained alongside the main policy model to keep the draft model aligned with the evolving policy. | +| `mtp_num_layers` | integer | `1` | Number of MTP layers to train. Must match the model's MTP architecture. | +| `mtp_loss_scaling_factor` | float | `0.1` | Scaling factor for MTP auxiliary loss relative to the main RL loss. | +| `mtp_detach_heads` | boolean | `True` | Whether to detach hidden states before passing to MTP heads. When True (recommended for RL), MTP loss gradients only update MTP parameters, preventing the MTP auxiliary loss from corrupting the main policy gradients. When False, MTP loss gradients also flow back to the main model. | | `rl_loss_weight` | float | `1.0` | RL loss weight | | `distill_loss_weight` | float | `0.005` | Distillation loss weight | diff --git a/docs/zh/tutorial/speculative_decoding.md b/docs/zh/tutorial/speculative_decoding.md new file mode 100644 index 0000000000..d8736d7f5c --- /dev/null +++ b/docs/zh/tutorial/speculative_decoding.md @@ -0,0 +1,262 @@ +# 使用 EAGLE 进行推测解码 + +## 概述 + +推测解码(Speculative Decoding)是一种加速自回归文本生成的技术。它使用一个轻量级的 +**草稿模型(Draft Model)**并行提出多个候选 token,然后由完整的**目标模型(Target Model)** +在一次前向传播中进行验证。当候选 token 被接受时,有效吞吐量可显著提升(通常 2-3 倍), +且不改变输出分布。 + +AReaL 集成了 **EAGLE**(Extrapolation Algorithm for Greater Language-model Efficiency) +作为推测解码后端。EAGLE 利用目标模型的隐藏状态通过小型辅助头预测未来 token,特别适合 +RL 训练循环中策略模型持续演化的场景。 + +### 为什么在 RL 训练中使用推测解码? + +在 RLHF / GRPO 训练流水线中,rollout 生成通常是吞吐瓶颈。推测解码通过以下方式直接 +解决这一问题: + +- 降低 rollout 阶段每个样本的生成延迟 +- 提高推理阶段的 GPU 利用率 +- 保持完全一致的输出质量(验证步骤是精确的) + +结合 **MTP(多 Token 预测)在线训练**,草稿模型能与不断演化的策略保持对齐,在整个 +训练过程中维持较高的接受率。 + +## 前提条件 + +启用推测解码前,请确保: + +1. **带 MTP 层的模型**:基座模型必须包含 MTP(多 Token 预测)头层。`Qwen/Qwen3-0.6B` + 等 Qwen3 系列模型自带 MTP 层,可作为 EAGLE 草稿头使用。 + +2. **SGLang 后端**:推测解码需要 SGLang 推理后端。请确保已安装并配置 SGLang: + + ```bash + pip install "sglang[all]>=0.4.7" + ``` + +3. **Megatron-Core >= 0.12.0**:MTP 在线训练需要 Megatron-Core 0.12.0 或更高版本, + 该版本包含了内置梯度隔离(embedding detach 和 functional_call lm_head)的 + `MultiTokenPrediction` 模块。这确保 MTP 损失梯度仅更新 MTP 层参数,不会污染 + 主策略模型的权重。 + +4. **充足的 GPU 显存**:草稿模型会在推理 GPU 上增加少量显存开销。如需要,可降低 + `sglang.mem_fraction_static`(例如从 `0.85` 降至 `0.80`)。 + +## 配置说明 + +### SGLang EAGLE 配置 + +推测解码在实验 YAML 的 `sglang` 部分进行配置。关键字段位于 `SGLangConfig` 中: + +```yaml +sglang: + model_path: ${actor.path} + dtype: bfloat16 + mem_fraction_static: 0.80 + context_length: 32768 + + # --- 推测解码配置 --- + speculative_algorithm: "EAGLE" # 或 "EAGLE3" + speculative_draft_model_path: null # null = 使用内置 MTP 头 + speculative_num_steps: 3 # 每次迭代的草稿步数 + speculative_eagle_topk: 1 # 草稿 token 选择的 top-k 值 + speculative_num_draft_tokens: 4 # 每步提出的草稿 token 数 + speculative_attention_mode: null # null 使用默认注意力机制 +``` + +| 参数 | 默认值 | 说明 | +|---|---|---| +| `speculative_algorithm` | `null` | 算法名称:`"EAGLE"` 或 `"EAGLE3"`。`null` 禁用推测解码。 | +| `speculative_draft_model_path` | `null` | 外部草稿模型路径。`null` 复用目标模型内置的 MTP 层。 | +| `speculative_num_steps` | `3` | EAGLE 在验证前执行的自回归草稿步数。 | +| `speculative_eagle_topk` | `1` | 每个草稿步保留的 top-k 候选数。 | +| `speculative_num_draft_tokens` | `4` | 每次推测迭代中馈入验证器的总草稿 token 数。 | +| `speculative_attention_mode` | `null` | 覆盖草稿阶段使用的注意力核。`null` 使用引擎默认值。 | + +### MTP 在线训练配置 + +为保持草稿模型与训练中的策略对齐,请在 `actor` 部分启用 MTP 在线训练: + +```yaml +actor: + backend: "megatron:d4p1t1" + path: Qwen/Qwen3-0.6B + + # --- MTP 在线训练 --- + enable_mtp_training: true + mtp_num_layers: 1 # 必须匹配模型的 MTP 架构 + mtp_loss_scaling_factor: 0.1 # MTP 损失相对于主 RL 损失的权重 + + # Megatron 特定的 MTP 设置(在 actor.megatron 中) + megatron: + mtp_num_layers: 1 # 与 actor.mtp_num_layers 一致 + mtp_loss_scaling_factor: 0.1 # 与 actor.mtp_loss_scaling_factor 一致 +``` + +| 参数 | 默认值 | 说明 | +|---|---|---| +| `enable_mtp_training` | `false` | MTP 在线训练的总开关。 | +| `mtp_num_layers` | `1` | 训练的 MTP 头层数。启用时必须 > 0。 | +| `mtp_loss_scaling_factor` | `0.1` | MTP 辅助损失的权重。必须在 (0, 1.0] 范围内。 | + +当 `enable_mtp_training` 为 `true` 时,训练器会在 MTP 头上计算辅助的下一 token +预测损失,并将其(按比例缩放后)加到主 RL 目标中。这确保了草稿头随策略变化持续 +提升预测准确性。 + +## 完整示例 + +以下是一个使用 4 GPU 的最小 GRPO + EAGLE GSM8K 配置: + +```yaml +experiment_name: gsm8k-grpo-eagle +trial_name: trial0 +seed: 42 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 4 + +actor: + backend: "megatron:d2p1t1" + path: Qwen/Qwen3-0.6B + enable_mtp_training: true + mtp_num_layers: 1 + mtp_loss_scaling_factor: 0.1 + +sglang: + model_path: ${actor.path} + speculative_algorithm: "EAGLE" + speculative_num_steps: 3 + speculative_num_draft_tokens: 4 + mem_fraction_static: 0.80 + +train_dataset: + path: openai/gsm8k + type: rl + batch_size: 128 +``` + +完整配置文件请参见 +[`examples/math/gsm8k_grpo_megatron_eagle.yaml`](https://github.com/inclusionAI/AReaL/blob/main/examples/math/gsm8k_grpo_megatron_eagle.yaml)。 + +## 监控 + +### 关键指标 + +训练过程中,请在日志或 WandB 面板中关注以下指标: + +1. **推测接受率(Speculative Accept Rate)** + - 日志中记录为 `spec_accept_rate`(= `spec_accept_token_num / spec_draft_token_num`) + - 对齐良好的草稿模型的健康接受率为 **0.6 - 0.9** + - 如果接受率降至 **0.4** 以下,说明草稿模型正在落后于策略 + +2. **MTP 损失(MTP Loss)** + - 训练统计中记录为 `mtp_loss` + - 应随时间下降;MTP 损失上升表明训练不稳定 + - 典型范围:**0.5 - 2.0**,取决于模型大小和任务 + +3. **生成吞吐量(Generation Throughput)** + - 对比启用和禁用推测解码时的 tokens/秒 + - 预期加速比:**1.5x - 3x**,取决于接受率和模型架构 + +### 接受率趋势解读 + +| 趋势 | 含义 | 建议操作 | +|---|---|---| +| 稳定在 0.7 以上 | 草稿模型对齐良好 | 无需操作 | +| 逐渐下降 | 策略演化速度快于草稿模型 | 增大 `mtp_loss_scaling_factor` | +| 突然下降 | 可能是学习率突变或数据分布变化 | 检查训练稳定性 | +| 极低(<0.3) | 草稿模型无效 | 验证 MTP 层是否在训练 | + +## 故障排除 + +### 接受率很低 + +1. **验证 MTP 训练已启用**:检查是否设置了 `actor.enable_mtp_training: true`。 + 未启用在线训练时,草稿模型会很快过时。 + +2. **检查 MTP 层数**:确保 `actor.mtp_num_layers` 与模型架构匹配。Qwen3 模型 + 通常有 1 个 MTP 层。 + +3. **增大 MTP 损失权重**:如果接受率随时间下降,尝试将 `mtp_loss_scaling_factor` + 从 `0.1` 增加到 `0.2` 或 `0.3`。 + +### 推理阶段显存不足(OOM) + +1. **降低显存比例**:将 `sglang.mem_fraction_static` 调低(例如 `0.75`)。 + +2. **减少草稿 token 数**:将 `speculative_num_draft_tokens` 从 `4` 降至 `2`。 + +3. **减少草稿步数**:将 `speculative_num_steps` 从 `3` 降至 `2`。 + +### 训练速度低于预期 + +1. **检查 GPU 分配**:确保推理和训练 GPU 正确分离。在 4 GPU 上可使用 + `sglang:d2p1t1` 配合 `megatron:d2p1t1` 以实现均衡分配。 + +2. **分析流水线**:启用 `perf_tracer.enabled: true` 以识别瓶颈是在生成、训练 + 还是数据加载阶段。 + +3. **临时禁用推测解码**:设置 `speculative_algorithm: null` 并对比吞吐量,以 + 判断开销是否来自推测本身。 + +### MTP 损失不下降 + +1. **验证模型支持 MTP**:并非所有模型架构都包含 MTP 头。检查模型配置中是否包含 + MTP 层定义。 + +2. **检查学习率**:MTP 头与 actor 共享优化器。如果基础学习率过低,MTP 训练可能 + 停滞。 + +3. **检查梯度流**:确保 `actor.gradient_checkpointing` 未影响 MTP 梯度计算。 + +## 高级配置 + +### 使用外部草稿模型 + +除了依赖内置 MTP 层,您也可以提供独立的草稿模型: + +```yaml +sglang: + speculative_algorithm: "EAGLE" + speculative_draft_model_path: /path/to/eagle-draft-model +``` + +注意:使用外部草稿模型时,通常应将 `enable_mtp_training` 设为 `false`,除非外部 +模型的权重也在训练中更新。 + +### EAGLE3 算法 + +EAGLE3 是一种改进变体,支持更灵活的树形结构推测: + +```yaml +sglang: + speculative_algorithm: "EAGLE3" + speculative_num_steps: 5 + speculative_eagle_topk: 2 + speculative_num_draft_tokens: 8 +``` + +EAGLE3 通常能达到更高的接受率,但扩展的草稿树会消耗更多显存。 + +### 草稿权重 CPU 备份 + +当使用共置的训练和推理模式(即相同 GPU 同时服务两者)时,草稿模型权重可能在 GPU +显存回收时丢失。启用 CPU 备份: + +```yaml +sglang: + enable_draft_weights_cpu_backup: true +``` + +这会保留草稿权重的 CPU 副本,在每个训练步之后恢复。 + +## 参考资料 + +- [EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty](https://arxiv.org/abs/2401.15077) +- [SGLang 文档](https://sgl-project.github.io/) +- [AReaL Megatron 后端教程](megatron.md) +- [AReaL 分配模式参考](../reference/alloc_mode.md) diff --git a/examples/math/gsm8k_grpo_megatron_eagle.yaml b/examples/math/gsm8k_grpo_megatron_eagle.yaml new file mode 100644 index 0000000000..7aa2603564 --- /dev/null +++ b/examples/math/gsm8k_grpo_megatron_eagle.yaml @@ -0,0 +1,200 @@ +# GSM8K GRPO with EAGLE Speculative Decoding +# This example demonstrates GRPO training with EAGLE speculative decoding +# and MTP online training on 4 GPUs using the Megatron backend. +# +# Usage: +# areal run examples/math/gsm8k_grpo_megatron_eagle.yaml +# +# Requirements: +# - 4 GPUs (2 for training, 2 for inference) +# - Model with MTP layers (e.g., Qwen3 series) +# - SGLang >= 0.4.7 + +experiment_name: gsm8k-grpo-megatron-eagle +trial_name: trial0 + +seed: 42 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 4 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +scheduler: + type: null + +rollout: + backend: "sglang:d2p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 128 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + backend: "megatron:d2p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen3-0.6B + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 3e-6 + weight_decay: 0.003 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behave_imp_weight_cap: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + + # MTP Online Training - keeps EAGLE draft heads aligned with evolving policy + enable_mtp_training: true + mtp_num_layers: 1 + mtp_loss_scaling_factor: 0.1 + + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + + megatron: + mtp_num_layers: 1 + mtp_loss_scaling_factor: 0.1 + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang with EAGLE Speculative Decoding +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.80 + + # EAGLE speculative decoding settings + speculative_algorithm: "EAGLE" + speculative_draft_model_path: null + speculative_num_steps: 3 + speculative_eagle_topk: 1 + speculative_num_draft_tokens: 4 + speculative_attention_mode: null + enable_draft_weights_cpu_backup: true + +# Datasets +train_dataset: + batch_size: 128 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 128 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/tests/speculative_decoding/__init__.py b/tests/speculative_decoding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/speculative_decoding/config_spec_only.yaml b/tests/speculative_decoding/config_spec_only.yaml new file mode 100644 index 0000000000..e50eb0ee8c --- /dev/null +++ b/tests/speculative_decoding/config_spec_only.yaml @@ -0,0 +1,170 @@ +# Test config: EAGLE speculative decoding only (no MTP training) +# Used by tests/speculative_decoding/test_speculative_decoding.py + +experiment_name: test-spec-decode-only +trial_name: trial0 + +seed: 42 +enable_offload: false +total_train_epochs: 1 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 2 + fileroot: /tmp/areal/test_spec_decode + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/test_spec_decode/name_resolve + +scheduler: + type: null + +rollout: + backend: "sglang:d1p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 16 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: false + +gconfig: + n_samples: 2 + min_new_tokens: 0 + max_new_tokens: 128 + greedy: false + temperature: 1.0 + +actor: + backend: "megatron:d1p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen3-0.6B + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 4096 + optimizer: + type: adam + lr: 1e-5 + weight_decay: 0.0 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.0 + eps_clip: 0.2 + temperature: ${gconfig.temperature} + reward_scaling: 1.0 + reward_bias: 0.0 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + max_new_tokens: ${gconfig.max_new_tokens} + + # MTP training DISABLED - only speculative decoding for inference + enable_mtp_training: false + + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 16 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 4096 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang with EAGLE (speculative decoding only, no MTP training) +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + context_length: 4096 + mem_fraction_static: 0.80 + + speculative_algorithm: "EAGLE" + speculative_draft_model_path: null + speculative_num_steps: 3 + speculative_eagle_topk: 1 + speculative_num_draft_tokens: 4 + +train_dataset: + batch_size: 16 + shuffle: true + pin_memory: true + num_workers: 2 + path: openai/gsm8k + type: rl + max_length: 256 + +valid_dataset: + batch_size: 16 + pin_memory: true + num_workers: 2 + path: openai/gsm8k + type: rl + +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/tests/speculative_decoding/config_spec_with_mtp.yaml b/tests/speculative_decoding/config_spec_with_mtp.yaml new file mode 100644 index 0000000000..e6e4a02b77 --- /dev/null +++ b/tests/speculative_decoding/config_spec_with_mtp.yaml @@ -0,0 +1,177 @@ +# Test config: EAGLE speculative decoding + MTP online training +# Used by tests/speculative_decoding/test_speculative_decoding.py + +experiment_name: test-spec-decode-mtp +trial_name: trial0 + +seed: 42 +enable_offload: false +total_train_epochs: 1 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 2 + fileroot: /tmp/areal/test_spec_mtp + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/test_spec_mtp/name_resolve + +scheduler: + type: null + +rollout: + backend: "sglang:d1p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 16 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: false + +gconfig: + n_samples: 2 + min_new_tokens: 0 + max_new_tokens: 128 + greedy: false + temperature: 1.0 + +actor: + backend: "megatron:d1p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen3-0.6B + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 4096 + optimizer: + type: adam + lr: 1e-5 + weight_decay: 0.0 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.0 + eps_clip: 0.2 + temperature: ${gconfig.temperature} + reward_scaling: 1.0 + reward_bias: 0.0 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + max_new_tokens: ${gconfig.max_new_tokens} + + # MTP Online Training ENABLED + enable_mtp_training: true + mtp_num_layers: 1 + mtp_loss_scaling_factor: 0.1 + + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 16 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + + megatron: + mtp_num_layers: 1 + mtp_loss_scaling_factor: 0.1 + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 4096 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang with EAGLE + MTP training +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + context_length: 4096 + mem_fraction_static: 0.80 + + speculative_algorithm: "EAGLE" + speculative_draft_model_path: null + speculative_num_steps: 3 + speculative_eagle_topk: 1 + speculative_num_draft_tokens: 4 + enable_draft_weights_cpu_backup: true + +train_dataset: + batch_size: 16 + shuffle: true + pin_memory: true + num_workers: 2 + path: openai/gsm8k + type: rl + max_length: 256 + +valid_dataset: + batch_size: 16 + pin_memory: true + num_workers: 2 + path: openai/gsm8k + type: rl + +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/tests/speculative_decoding/entrypoint.py b/tests/speculative_decoding/entrypoint.py new file mode 100644 index 0000000000..4c14d2faa2 --- /dev/null +++ b/tests/speculative_decoding/entrypoint.py @@ -0,0 +1,253 @@ +"""E2E test entrypoint for speculative decoding with EAGLE. + +This module provides a MinimalSpecDecodePPOTrainer that wraps the standard +PPOTrainer to collect and validate speculative decoding statistics (accept +rate, draft tokens) and MTP training loss during end-to-end test runs. + +Usage: + python -m tests.speculative_decoding.entrypoint --config +""" + +import argparse +import logging +import os +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class SpecDecodeStats: + """Accumulated speculative decoding statistics across training steps.""" + + total_accept_tokens: int = 0 + total_draft_tokens: int = 0 + step_accept_rates: List[float] = field(default_factory=list) + mtp_losses: List[float] = field(default_factory=list) + rewards: List[float] = field(default_factory=list) + + @property + def overall_accept_rate(self) -> float: + """Compute overall accept rate across all steps.""" + if self.total_draft_tokens == 0: + return 0.0 + return self.total_accept_tokens / self.total_draft_tokens + + @property + def mean_mtp_loss(self) -> float: + """Compute mean MTP loss across all steps.""" + if not self.mtp_losses: + return float("nan") + return sum(self.mtp_losses) / len(self.mtp_losses) + + @property + def mean_reward(self) -> float: + """Compute mean reward across all steps.""" + if not self.rewards: + return float("nan") + return sum(self.rewards) / len(self.rewards) + + def summary(self) -> Dict[str, Any]: + """Return a summary dict of all collected statistics.""" + return { + "total_accept_tokens": self.total_accept_tokens, + "total_draft_tokens": self.total_draft_tokens, + "overall_accept_rate": self.overall_accept_rate, + "num_steps": len(self.step_accept_rates), + "step_accept_rates": self.step_accept_rates, + "mean_mtp_loss": self.mean_mtp_loss, + "mean_reward": self.mean_reward, + "mtp_losses": self.mtp_losses, + "rewards": self.rewards, + } + + +class MinimalSpecDecodePPOTrainer: + """A minimal wrapper around PPOTrainer for speculative decoding E2E tests. + + This trainer intercepts training statistics to collect and validate + speculative decoding metrics including: + - Speculative accept rate (spec_accept_token_num / spec_draft_token_num) + - MTP auxiliary loss (when enable_mtp_training is True) + - Reward statistics + + It is designed for use in integration tests, not production training. + """ + + def __init__(self, config_path: str): + """Initialize the trainer with a config file path. + + Args: + config_path: Path to the experiment YAML configuration file. + """ + self.config_path = config_path + self.stats = SpecDecodeStats() + self._trainer = None + self._config = None + + def _load_config(self) -> Dict[str, Any]: + """Load and parse the YAML configuration file. + + Returns: + Parsed configuration dictionary. + """ + try: + from omegaconf import OmegaConf + + cfg = OmegaConf.load(self.config_path) + self._config = OmegaConf.to_container(cfg, resolve=True) + except ImportError: + import yaml + + with open(self.config_path) as f: + self._config = yaml.safe_load(f) + return self._config + + def _collect_step_stats(self, train_stat: Dict[str, Any]) -> None: + """Extract speculative decoding stats from a training step result. + + Args: + train_stat: Dictionary of statistics from one training step. + """ + # Collect speculative decoding accept/draft token counts + accept_tokens = train_stat.get("spec_accept_token_num", 0) + draft_tokens = train_stat.get("spec_draft_token_num", 0) + + if draft_tokens > 0: + self.stats.total_accept_tokens += accept_tokens + self.stats.total_draft_tokens += draft_tokens + step_rate = accept_tokens / draft_tokens + self.stats.step_accept_rates.append(step_rate) + logger.info( + f"[SpecDecode] Step accept rate: {step_rate:.4f} " + f"({accept_tokens}/{draft_tokens})" + ) + + # Collect MTP loss if present + mtp_loss = train_stat.get("mtp_loss", None) + if mtp_loss is not None: + self.stats.mtp_losses.append(float(mtp_loss)) + logger.info(f"[MTPTrain] MTP loss: {mtp_loss:.6f}") + + # Collect rewards + reward = train_stat.get("reward/mean", train_stat.get("reward", None)) + if reward is not None: + self.stats.rewards.append(float(reward)) + + def run(self, max_steps: Optional[int] = None) -> SpecDecodeStats: + """Run the training loop and collect speculative decoding statistics. + + Args: + max_steps: Maximum number of training steps. None runs the full + config (total_train_epochs). + + Returns: + SpecDecodeStats with all collected metrics. + """ + config = self._load_config() + experiment_name = config.get("experiment_name", "test-spec-decode") + logger.info( + f"Starting MinimalSpecDecodePPOTrainer for '{experiment_name}' " + f"with config: {self.config_path}" + ) + + # Log speculative decoding configuration + sglang_cfg = config.get("sglang", {}) + actor_cfg = config.get("actor", {}) + logger.info( + f"Speculative config: algorithm={sglang_cfg.get('speculative_algorithm')}, " + f"num_steps={sglang_cfg.get('speculative_num_steps')}, " + f"num_draft_tokens={sglang_cfg.get('speculative_num_draft_tokens')}" + ) + logger.info( + f"MTP training: enabled={actor_cfg.get('enable_mtp_training', False)}, " + f"num_layers={actor_cfg.get('mtp_num_layers', 0)}, " + f"loss_scaling={actor_cfg.get('mtp_loss_scaling_factor', 0.0)}" + ) + + try: + from areal.trainer.rl_trainer import PPOTrainer + + self._trainer = PPOTrainer(config) + step = 0 + for train_stat in self._trainer.train(): + self._collect_step_stats(train_stat) + step += 1 + if max_steps is not None and step >= max_steps: + logger.info(f"Reached max_steps={max_steps}, stopping.") + break + except ImportError as e: + logger.warning( + f"Could not import PPOTrainer: {e}. " + f"Running in dry-run mode (config validation only)." + ) + except Exception as e: + logger.error(f"Training failed with error: {e}") + raise + + # Print summary + summary = self.stats.summary() + logger.info(f"=== Speculative Decoding E2E Summary ===") + logger.info(f" Total steps: {summary['num_steps']}") + logger.info(f" Overall accept rate: {summary['overall_accept_rate']:.4f}") + logger.info(f" Mean MTP loss: {summary['mean_mtp_loss']:.4f}") + logger.info(f" Mean reward: {summary['mean_reward']:.4f}") + + return self.stats + + +def main(): + """CLI entrypoint for running speculative decoding E2E tests.""" + parser = argparse.ArgumentParser( + description="Run speculative decoding E2E test with AReaL PPOTrainer" + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to experiment YAML config file", + ) + parser.add_argument( + "--max-steps", + type=int, + default=None, + help="Maximum training steps (default: run full config)", + ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging level", + ) + args = parser.parse_args() + + logging.basicConfig( + level=getattr(logging, args.log_level), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + trainer = MinimalSpecDecodePPOTrainer(config_path=args.config) + stats = trainer.run(max_steps=args.max_steps) + + summary = stats.summary() + print("\n=== Final Statistics ===") + for key, value in summary.items(): + if isinstance(value, list): + print(f" {key}: [{len(value)} entries]") + elif isinstance(value, float): + print(f" {key}: {value:.4f}") + else: + print(f" {key}: {value}") + + # Exit with error if no steps completed and we expected some + if summary["num_steps"] == 0 and args.max_steps != 0: + logger.warning("No training steps were completed.") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/speculative_decoding/test_speculative_decoding.py b/tests/speculative_decoding/test_speculative_decoding.py new file mode 100644 index 0000000000..09059db6eb --- /dev/null +++ b/tests/speculative_decoding/test_speculative_decoding.py @@ -0,0 +1,411 @@ +"""Tests for speculative decoding (EAGLE) configuration and E2E training. + +This module contains: +- TestSpeculativeDecodingConfig: Unit tests for config field parsing and validation +- TestSpeculativeDecodingE2E: End-to-end tests for speculative decoding training + +Run unit tests: + pytest tests/speculative_decoding/test_speculative_decoding.py -v -k "Config" + +Run E2E tests (requires GPUs): + pytest tests/speculative_decoding/test_speculative_decoding.py -v -k "E2E" +""" + +import math +import os +from pathlib import Path +from typing import Any, Dict +from unittest.mock import MagicMock, patch + +import pytest +import yaml + +# --------------------------------------------------------------------------- +# Paths to test config files (relative to this file) +# --------------------------------------------------------------------------- +_TEST_DIR = Path(__file__).resolve().parent +_CONFIG_SPEC_ONLY = _TEST_DIR / "config_spec_only.yaml" +_CONFIG_SPEC_WITH_MTP = _TEST_DIR / "config_spec_with_mtp.yaml" + + +def _load_yaml(path: Path) -> Dict[str, Any]: + """Load a YAML file and return as dict (without OmegaConf resolution).""" + with open(path) as f: + return yaml.safe_load(f) + + +# ============================================================================ +# Unit Tests: Configuration Parsing and Validation +# ============================================================================ + + +class TestSpeculativeDecodingConfig: + """Unit tests for speculative decoding configuration fields.""" + + # ------------------------------------------------------------------ + # SGLang speculative decoding config fields + # ------------------------------------------------------------------ + + def test_sglang_config_has_speculative_fields(self): + """SGLangConfig dataclass should expose all speculative decoding fields.""" + from areal.api.cli_args import SGLangConfig + + cfg = SGLangConfig() + assert hasattr(cfg, "speculative_algorithm") + assert hasattr(cfg, "speculative_draft_model_path") + assert hasattr(cfg, "speculative_num_steps") + assert hasattr(cfg, "speculative_eagle_topk") + assert hasattr(cfg, "speculative_num_draft_tokens") + assert hasattr(cfg, "speculative_attention_mode") + + def test_sglang_config_defaults(self): + """Default values should disable speculative decoding.""" + from areal.api.cli_args import SGLangConfig + + cfg = SGLangConfig() + assert cfg.speculative_algorithm is None + assert cfg.speculative_draft_model_path is None + assert cfg.speculative_num_steps == 3 + assert cfg.speculative_eagle_topk == 1 + assert cfg.speculative_num_draft_tokens == 4 + assert cfg.speculative_attention_mode is None + + def test_sglang_config_eagle_values(self): + """SGLangConfig should accept EAGLE algorithm settings.""" + from areal.api.cli_args import SGLangConfig + + cfg = SGLangConfig( + speculative_algorithm="EAGLE", + speculative_num_steps=5, + speculative_eagle_topk=2, + speculative_num_draft_tokens=8, + ) + assert cfg.speculative_algorithm == "EAGLE" + assert cfg.speculative_num_steps == 5 + assert cfg.speculative_eagle_topk == 2 + assert cfg.speculative_num_draft_tokens == 8 + + def test_sglang_config_eagle3_values(self): + """SGLangConfig should accept EAGLE3 algorithm settings.""" + from areal.api.cli_args import SGLangConfig + + cfg = SGLangConfig(speculative_algorithm="EAGLE3") + assert cfg.speculative_algorithm == "EAGLE3" + + def test_sglang_config_draft_model_path(self): + """SGLangConfig should accept an external draft model path.""" + from areal.api.cli_args import SGLangConfig + + cfg = SGLangConfig( + speculative_algorithm="EAGLE", + speculative_draft_model_path="/models/eagle-draft", + ) + assert cfg.speculative_draft_model_path == "/models/eagle-draft" + + def test_sglang_config_enable_draft_weights_cpu_backup(self): + """SGLangConfig should expose enable_draft_weights_cpu_backup field.""" + from areal.api.cli_args import SGLangConfig + + cfg = SGLangConfig() + assert hasattr(cfg, "enable_draft_weights_cpu_backup") + + # ------------------------------------------------------------------ + # PPOActorConfig MTP training fields + # ------------------------------------------------------------------ + + def test_actor_config_has_mtp_fields(self): + """PPOActorConfig should expose MTP training fields.""" + from areal.api.cli_args import PPOActorConfig + + # PPOActorConfig requires certain fields; check class attributes + assert hasattr(PPOActorConfig, "enable_mtp_training") + assert hasattr(PPOActorConfig, "mtp_num_layers") + assert hasattr(PPOActorConfig, "mtp_loss_scaling_factor") + + def test_actor_config_mtp_defaults(self): + """MTP training should be disabled by default.""" + from areal.api.cli_args import PPOActorConfig + + # Access field defaults from the dataclass + import dataclasses + + fields = {f.name: f for f in dataclasses.fields(PPOActorConfig)} + assert fields["enable_mtp_training"].default is False + assert fields["mtp_num_layers"].default == 1 + assert fields["mtp_loss_scaling_factor"].default == 0.1 + + def test_actor_config_mtp_validation_num_layers_zero(self): + """Enabling MTP with mtp_num_layers=0 should raise ValueError.""" + from areal.api.cli_args import PPOActorConfig + + with pytest.raises(ValueError, match="mtp_num_layers must be > 0"): + PPOActorConfig( + enable_mtp_training=True, + mtp_num_layers=0, + mtp_loss_scaling_factor=0.1, + ) + + def test_actor_config_mtp_validation_scaling_factor_out_of_range(self): + """MTP loss scaling factor outside (0, 1.0] should raise ValueError.""" + from areal.api.cli_args import PPOActorConfig + + with pytest.raises(ValueError, match="mtp_loss_scaling_factor must be in"): + PPOActorConfig( + enable_mtp_training=True, + mtp_num_layers=1, + mtp_loss_scaling_factor=1.5, + ) + + def test_actor_config_mtp_validation_scaling_factor_zero(self): + """MTP loss scaling factor of 0 should raise ValueError.""" + from areal.api.cli_args import PPOActorConfig + + with pytest.raises(ValueError, match="mtp_loss_scaling_factor must be in"): + PPOActorConfig( + enable_mtp_training=True, + mtp_num_layers=1, + mtp_loss_scaling_factor=0.0, + ) + + def test_actor_config_mtp_validation_negative_layers(self): + """Negative mtp_num_layers should raise ValueError.""" + from areal.api.cli_args import PPOActorConfig + + with pytest.raises(ValueError, match="mtp_num_layers must be > 0"): + PPOActorConfig( + enable_mtp_training=True, + mtp_num_layers=-1, + mtp_loss_scaling_factor=0.1, + ) + + # ------------------------------------------------------------------ + # MegatronEngineConfig MTP fields + # ------------------------------------------------------------------ + + def test_megatron_config_has_mtp_fields(self): + """MegatronEngineConfig should have MTP-related fields.""" + from areal.api.cli_args import MegatronEngineConfig + + assert hasattr(MegatronEngineConfig, "mtp_num_layers") + assert hasattr(MegatronEngineConfig, "mtp_loss_scaling_factor") + + def test_megatron_config_mtp_defaults(self): + """MegatronEngineConfig MTP defaults should be 0 / 0.1.""" + from areal.api.cli_args import MegatronEngineConfig + + import dataclasses + + fields = {f.name: f for f in dataclasses.fields(MegatronEngineConfig)} + assert fields["mtp_num_layers"].default == 0 + assert fields["mtp_loss_scaling_factor"].default == 0.1 + + # ------------------------------------------------------------------ + # YAML config file parsing + # ------------------------------------------------------------------ + + def test_spec_only_yaml_loads(self): + """config_spec_only.yaml should load without errors.""" + cfg = _load_yaml(_CONFIG_SPEC_ONLY) + assert cfg["experiment_name"] == "test-spec-decode-only" + assert cfg["sglang"]["speculative_algorithm"] == "EAGLE" + + def test_spec_only_yaml_mtp_disabled(self): + """config_spec_only.yaml should have MTP training disabled.""" + cfg = _load_yaml(_CONFIG_SPEC_ONLY) + assert cfg["actor"]["enable_mtp_training"] is False + + def test_spec_with_mtp_yaml_loads(self): + """config_spec_with_mtp.yaml should load without errors.""" + cfg = _load_yaml(_CONFIG_SPEC_WITH_MTP) + assert cfg["experiment_name"] == "test-spec-decode-mtp" + assert cfg["sglang"]["speculative_algorithm"] == "EAGLE" + + def test_spec_with_mtp_yaml_mtp_enabled(self): + """config_spec_with_mtp.yaml should have MTP training enabled.""" + cfg = _load_yaml(_CONFIG_SPEC_WITH_MTP) + assert cfg["actor"]["enable_mtp_training"] is True + assert cfg["actor"]["mtp_num_layers"] == 1 + assert cfg["actor"]["mtp_loss_scaling_factor"] == 0.1 + + def test_spec_with_mtp_yaml_megatron_mtp(self): + """config_spec_with_mtp.yaml should have Megatron MTP settings.""" + cfg = _load_yaml(_CONFIG_SPEC_WITH_MTP) + megatron_cfg = cfg["actor"]["megatron"] + assert megatron_cfg["mtp_num_layers"] == 1 + assert megatron_cfg["mtp_loss_scaling_factor"] == 0.1 + + def test_spec_with_mtp_yaml_draft_cpu_backup(self): + """config_spec_with_mtp.yaml should enable draft weights CPU backup.""" + cfg = _load_yaml(_CONFIG_SPEC_WITH_MTP) + assert cfg["sglang"]["enable_draft_weights_cpu_backup"] is True + + def test_spec_only_yaml_no_draft_cpu_backup(self): + """config_spec_only.yaml should not set draft weights CPU backup.""" + cfg = _load_yaml(_CONFIG_SPEC_ONLY) + assert "enable_draft_weights_cpu_backup" not in cfg["sglang"] + + +# ============================================================================ +# E2E Tests: Speculative Decoding Training +# ============================================================================ + + +def _has_gpus(min_count: int = 2) -> bool: + """Check if sufficient GPUs are available.""" + try: + import torch + + return torch.cuda.is_available() and torch.cuda.device_count() >= min_count + except ImportError: + return False + + +@pytest.mark.skipif(not _has_gpus(2), reason="Requires at least 2 GPUs") +class TestSpeculativeDecodingE2E: + """End-to-end tests for speculative decoding training. + + These tests require GPU resources and a model checkpoint. They verify + that the full training loop runs correctly with speculative decoding + enabled, producing valid statistics. + """ + + def test_spec_only_e2e(self): + """E2E test: EAGLE speculative decoding without MTP training. + + Verifies that: + 1. Training completes without errors + 2. Speculative decoding stats are collected + 3. Accept rate is within valid range [0, 1] + """ + from tests.speculative_decoding.entrypoint import ( + MinimalSpecDecodePPOTrainer, + ) + + trainer = MinimalSpecDecodePPOTrainer( + config_path=str(_CONFIG_SPEC_ONLY) + ) + stats = trainer.run(max_steps=2) + summary = stats.summary() + + # Verify stats were collected + assert summary["num_steps"] > 0, "Expected at least 1 training step" + + # Accept rate should be in valid range + if summary["total_draft_tokens"] > 0: + assert 0.0 <= summary["overall_accept_rate"] <= 1.0, ( + f"Accept rate {summary['overall_accept_rate']} out of range" + ) + + # MTP loss should NOT be present (MTP training disabled) + assert len(stats.mtp_losses) == 0, ( + "MTP losses should be empty when enable_mtp_training=False" + ) + + def test_spec_with_mtp_e2e(self): + """E2E test: EAGLE speculative decoding with MTP online training. + + Verifies that: + 1. Training completes without errors + 2. Speculative decoding stats are collected + 3. MTP loss is recorded and is finite + 4. Accept rate is within valid range [0, 1] + """ + from tests.speculative_decoding.entrypoint import ( + MinimalSpecDecodePPOTrainer, + ) + + trainer = MinimalSpecDecodePPOTrainer( + config_path=str(_CONFIG_SPEC_WITH_MTP) + ) + stats = trainer.run(max_steps=2) + summary = stats.summary() + + # Verify stats were collected + assert summary["num_steps"] > 0, "Expected at least 1 training step" + + # Accept rate should be in valid range + if summary["total_draft_tokens"] > 0: + assert 0.0 <= summary["overall_accept_rate"] <= 1.0, ( + f"Accept rate {summary['overall_accept_rate']} out of range" + ) + + # MTP loss should be present and finite + assert len(stats.mtp_losses) > 0, ( + "MTP losses should be recorded when enable_mtp_training=True" + ) + for loss in stats.mtp_losses: + assert math.isfinite(loss), f"MTP loss is not finite: {loss}" + + def test_spec_decode_rewards_collected(self): + """Verify that rewards are collected during speculative decoding training.""" + from tests.speculative_decoding.entrypoint import ( + MinimalSpecDecodePPOTrainer, + ) + + trainer = MinimalSpecDecodePPOTrainer( + config_path=str(_CONFIG_SPEC_WITH_MTP) + ) + stats = trainer.run(max_steps=3) + + # Rewards should be collected + assert len(stats.rewards) > 0, "Expected rewards to be collected" + for reward in stats.rewards: + assert math.isfinite(reward), f"Reward is not finite: {reward}" + + +# ============================================================================ +# Unit Tests: SpecDecodeStats helper class +# ============================================================================ + + +class TestSpecDecodeStats: + """Unit tests for the SpecDecodeStats dataclass.""" + + def test_empty_stats(self): + """Empty stats should return 0 accept rate and NaN losses.""" + from tests.speculative_decoding.entrypoint import SpecDecodeStats + + stats = SpecDecodeStats() + assert stats.overall_accept_rate == 0.0 + assert math.isnan(stats.mean_mtp_loss) + assert math.isnan(stats.mean_reward) + + def test_accept_rate_calculation(self): + """Accept rate should be accept_tokens / draft_tokens.""" + from tests.speculative_decoding.entrypoint import SpecDecodeStats + + stats = SpecDecodeStats(total_accept_tokens=75, total_draft_tokens=100) + assert stats.overall_accept_rate == pytest.approx(0.75) + + def test_mean_mtp_loss(self): + """Mean MTP loss should average all recorded losses.""" + from tests.speculative_decoding.entrypoint import SpecDecodeStats + + stats = SpecDecodeStats(mtp_losses=[1.0, 2.0, 3.0]) + assert stats.mean_mtp_loss == pytest.approx(2.0) + + def test_mean_reward(self): + """Mean reward should average all recorded rewards.""" + from tests.speculative_decoding.entrypoint import SpecDecodeStats + + stats = SpecDecodeStats(rewards=[0.5, 1.0, 1.5]) + assert stats.mean_reward == pytest.approx(1.0) + + def test_summary_keys(self): + """Summary dict should contain all expected keys.""" + from tests.speculative_decoding.entrypoint import SpecDecodeStats + + stats = SpecDecodeStats() + summary = stats.summary() + expected_keys = { + "total_accept_tokens", + "total_draft_tokens", + "overall_accept_rate", + "num_steps", + "step_accept_rates", + "mean_mtp_loss", + "mean_reward", + "mtp_losses", + "rewards", + } + assert set(summary.keys()) == expected_keys From 71cecba5fd3603504823ca6f0fd41f0436a658a7 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 13 Apr 2026 18:31:12 +0800 Subject: [PATCH 002/140] feat: add config --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 198 ++++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 examples/math/gsm8k_grpo_megatron_mimo.yaml diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml new file mode 100644 index 0000000000..73a9d90778 --- /dev/null +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -0,0 +1,198 @@ +experiment_name: gsm8k-grpo-megatron +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 6 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + + +scheduler: + type: null + +rollout: + backend: "sglang:d1p1t2" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 128 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + backend: "megatron:d1p1t4" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: /workspace/models/MiMo-7B-RL + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 3e-6 + weight_decay: 0.003 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behave_imp_weight_cap: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + + # MTP Online Training - keeps EAGLE draft heads aligned with evolving policy + enable_mtp_training: true + mtp_num_layers: 1 + mtp_loss_scaling_factor: 0.1 + + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + + megatron: + mtp_num_layers: 1 + mtp_loss_scaling_factor: 0.1 + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.5 + + # EAGLE speculative decoding settings + speculative_algorithm: "EAGLE" + speculative_draft_model_path: null + speculative_num_steps: 3 + speculative_eagle_topk: 1 + speculative_num_draft_tokens: 4 + speculative_attention_mode: null + enable_draft_weights_cpu_backup: true + + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.9 + +# datasets +train_dataset: + batch_size: 128 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 128 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false From ac145b8f006aef6b694b653d15d16e6f38ada3b0 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 13 Apr 2026 21:21:04 +0800 Subject: [PATCH 003/140] feat: fix log --- areal/engine/megatron_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 80e00dd998..a492e39bf5 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -133,6 +133,7 @@ def parameters(self, *args, **kwargs) -> Iterator[nn.Parameter]: class MegatronEngine(TrainEngine): def __init__(self, config: TrainEngineConfig): self.config = config + self.logger = logging.getLogger("[MegatronEngine]") self.hf_config: PretrainedConfig self.tf_config: TransformerConfig self.model: _MegatronModelList | None = None From c82b174f136aa2125a45bdcf638660c760ccd3f9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 13 Apr 2026 22:38:32 +0800 Subject: [PATCH 004/140] feat: fix --- areal/engine/megatron_engine.py | 13 +++--- areal/models/mcore/registry.py | 74 +++++++++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 9 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index a492e39bf5..59d3ab9db8 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -341,14 +341,14 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self._check_and_apply_fp8_config() self._validate_fp8_consistency() - # Propagate MTP config to mcore_config for model creation + # Propagate MTP config to tf_config (TransformerConfig) for model creation if self.enable_mtp_training: - self.mcore_config.mtp_num_layers = self.mtp_num_layers - self.mcore_config.mtp_loss_scaling_factor = self.mtp_loss_scaling_factor - if hasattr(self.mcore_config, "mtp_detach_heads"): - self.mcore_config.mtp_detach_heads = self.mtp_detach_heads + self.tf_config.mtp_num_layers = self.mtp_num_layers + self.tf_config.mtp_loss_scaling_factor = self.mtp_loss_scaling_factor + if hasattr(self.tf_config, "mtp_detach_heads"): + self.tf_config.mtp_detach_heads = self.mtp_detach_heads self.logger.info( - f"[MTPTrain] Propagated MTP config to mcore_config: " + f"[MTPTrain] Propagated MTP config to tf_config: " f"mtp_num_layers={self.mtp_num_layers}, " f"mtp_loss_scaling_factor={self.mtp_loss_scaling_factor}, " f"mtp_detach_heads={self.mtp_detach_heads}" @@ -363,6 +363,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): bridge_type=self.bridge_cls, is_critic=self.config.is_critic, use_lora=self.config.use_lora, + enable_mtp=self.enable_mtp_training, ) self.model = _MegatronModelList(models) diff --git a/areal/models/mcore/registry.py b/areal/models/mcore/registry.py index 477b476545..5b29a585ee 100644 --- a/areal/models/mcore/registry.py +++ b/areal/models/mcore/registry.py @@ -100,6 +100,62 @@ def unwrap_to_gpt_model(model: torch.nn.Module) -> GPTModel: return _model +def _ensure_mtp_spec_compat(): + """Patch get_gpt_mtp_block_spec to gracefully handle TransformerConfig as spec. + + mbridge may pass a raw TransformerConfig object as the ``spec`` argument to + ``get_gpt_mtp_block_spec``, which only accepts ``ModuleSpec`` or + ``TransformerBlockSubmodules``. This monkey-patch transparently converts + ``TransformerConfig`` into the correct ``ModuleSpec`` so that mbridge works + without modification. + """ + try: + from megatron.core.models.gpt import gpt_layer_specs as _specs_mod + except ImportError: + logger.warning( + "[MTPCompat] Cannot import gpt_layer_specs from megatron.core; " + "skipping MTP spec compatibility patch." + ) + return + + if getattr(_specs_mod, "_areal_mtp_compat_patched", False): + return + + _orig_fn = _specs_mod.get_gpt_mtp_block_spec + + def _compat_get_gpt_mtp_block_spec( + config, spec, use_transformer_engine=True, **kwargs + ): + if isinstance(spec, TransformerConfig): + logger.info( + "[MTPCompat] Auto-converting TransformerConfig -> ModuleSpec " + "for get_gpt_mtp_block_spec (use_transformer_engine=%s).", + use_transformer_engine, + ) + _get_decoder = getattr(_specs_mod, "get_gpt_decoder_block_spec", None) + if _get_decoder is not None: + decoder_block_spec = _get_decoder( + config=config, use_transformer_engine=use_transformer_engine + ) + spec = decoder_block_spec.layer_specs[-1] + logger.info( + "[MTPCompat] Resolved spec via get_gpt_decoder_block_spec." + ) + elif use_transformer_engine: + spec = _specs_mod.get_gpt_layer_with_transformer_engine_spec() + logger.info( + "[MTPCompat] Resolved spec via get_gpt_layer_with_transformer_engine_spec." + ) + else: + spec = _specs_mod.get_gpt_layer_local_spec() + logger.info("[MTPCompat] Resolved spec via get_gpt_layer_local_spec.") + return _orig_fn(config, spec, use_transformer_engine, **kwargs) + + _specs_mod.get_gpt_mtp_block_spec = _compat_get_gpt_mtp_block_spec + _specs_mod._areal_mtp_compat_patched = True + logger.info("[MTPCompat] Patched get_gpt_mtp_block_spec for TransformerConfig compat.") + + # Model registry for different architectures def make_hf_and_mcore_config( hf_path: str, @@ -162,8 +218,18 @@ def make_mcore_model( bridge_type: str = "mbridge", is_critic: bool = False, use_lora: bool = False, + enable_mtp: bool = False, ) -> list[GPTModel | DDP]: if bridge is not None and bridge_type == "mbridge": + # Patch get_gpt_mtp_block_spec before mbridge calls it so that a + # TransformerConfig passed as ``spec`` is auto-converted to the + # correct ModuleSpec type expected by megatron-core. + if enable_mtp: + _ensure_mtp_spec_compat() + logger.info( + "[MTPTrain] Applied MTP spec compatibility patch before mbridge model creation." + ) + models = bridge.get_model( # TODO: Add DDP options when supporting training wrap_with_ddp=mcore_config.wrap_with_ddp, @@ -281,14 +347,16 @@ def make_mcore_model( mtp_num_layers = getattr(tf_config, "mtp_num_layers", 0) if mtp_num_layers > 0: try: - from megatron.core.models.gpt.gpt_layer_specs import get_mtp_block_spec - mtp_block_spec = get_mtp_block_spec(tf_config, transformer_layer_spec) + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec + mtp_block_spec = get_gpt_mtp_block_spec( + tf_config, transformer_layer_spec, use_transformer_engine=True + ) logger.info( f"[MTPTrain] Created MTP block spec with {mtp_num_layers} layers" ) except ImportError: logger.warning( - "[MTPTrain] Cannot import get_mtp_block_spec from megatron.core. " + "[MTPTrain] Cannot import get_gpt_mtp_block_spec from megatron.core. " "MTP layers will not be created. Ensure megatron-core >= 0.11.0." ) rope_scaling_args = {} From 328236636e168c231498bc1708f266fbb2a983fd Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 13 Apr 2026 22:58:05 +0800 Subject: [PATCH 005/140] feat: fix --- areal/models/mcore/registry.py | 123 ++++++++++++++++++++++++--------- 1 file changed, 90 insertions(+), 33 deletions(-) diff --git a/areal/models/mcore/registry.py b/areal/models/mcore/registry.py index 5b29a585ee..a614f78f48 100644 --- a/areal/models/mcore/registry.py +++ b/areal/models/mcore/registry.py @@ -101,13 +101,31 @@ def unwrap_to_gpt_model(model: torch.nn.Module) -> GPTModel: def _ensure_mtp_spec_compat(): - """Patch get_gpt_mtp_block_spec to gracefully handle TransformerConfig as spec. + """Patch MTP block-spec functions to gracefully handle TransformerConfig as *spec*. - mbridge may pass a raw TransformerConfig object as the ``spec`` argument to - ``get_gpt_mtp_block_spec``, which only accepts ``ModuleSpec`` or - ``TransformerBlockSubmodules``. This monkey-patch transparently converts - ``TransformerConfig`` into the correct ``ModuleSpec`` so that mbridge works - without modification. + **Why multi-level patching is needed** + + ``mbridge.models.mimo`` does:: + + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec + + at module load time, which creates a *local* binding to the original + function object. Simply replacing ``gpt_layer_specs.get_gpt_mtp_block_spec`` + does NOT affect that already-bound local reference — the original function + will still be called by mimo. + + However, the original ``get_gpt_mtp_block_spec`` internally calls + ``get_gpt_mtp_block_spec_for_backend`` through the module's **global + namespace**, which IS resolved at call time. Therefore we apply patches + at three levels for maximum robustness: + + 1. ``get_gpt_mtp_block_spec_for_backend`` on the module — catches calls + coming through *any* import path (including mimo's local reference). + This is the **critical** patch that actually fixes the bug. + 2. ``get_gpt_mtp_block_spec`` on the module — catches future callers that + access it via ``gpt_layer_specs.get_gpt_mtp_block_spec``. + 3. The local reference inside ``mbridge.models.mimo`` (if importable) — + belt-and-suspenders for the direct ``from-import`` case. """ try: from megatron.core.models.gpt import gpt_layer_specs as _specs_mod @@ -121,39 +139,78 @@ def _ensure_mtp_spec_compat(): if getattr(_specs_mod, "_areal_mtp_compat_patched", False): return + # ----- helper: convert TransformerConfig → proper ModuleSpec ----- + def _convert_spec_if_needed(config, spec, use_transformer_engine=True): + if not isinstance(spec, TransformerConfig): + return spec + logger.info( + "[MTPCompat] Auto-converting TransformerConfig -> ModuleSpec " + "for get_gpt_mtp_block_spec (use_transformer_engine=%s).", + use_transformer_engine, + ) + _get_decoder = getattr(_specs_mod, "get_gpt_decoder_block_spec", None) + if _get_decoder is not None: + decoder_block_spec = _get_decoder( + config=config, use_transformer_engine=use_transformer_engine + ) + spec = decoder_block_spec.layer_specs[-1] + logger.info( + "[MTPCompat] Resolved spec via get_gpt_decoder_block_spec." + ) + elif use_transformer_engine: + spec = _specs_mod.get_gpt_layer_with_transformer_engine_spec() + logger.info( + "[MTPCompat] Resolved spec via " + "get_gpt_layer_with_transformer_engine_spec." + ) + else: + spec = _specs_mod.get_gpt_layer_local_spec() + logger.info("[MTPCompat] Resolved spec via get_gpt_layer_local_spec.") + return spec + + # get_gpt_mtp_block_spec_for_backend --- + # This is the lowest-level function that validates the spec type. + # Because the original get_gpt_mtp_block_spec resolves this name + # through the module's global dict at call time, patching here + # intercepts ALL callers — including mimo's from-imported reference. + _orig_backend_fn = _specs_mod.get_gpt_mtp_block_spec_for_backend + + def _compat_backend(config, spec, use_transformer_engine=True, **kwargs): + spec = _convert_spec_if_needed(config, spec, use_transformer_engine) + return _orig_backend_fn(config, spec, use_transformer_engine, **kwargs) + + _specs_mod.get_gpt_mtp_block_spec_for_backend = _compat_backend + + # --- Patch 2: get_gpt_mtp_block_spec (top-level entry point) --- _orig_fn = _specs_mod.get_gpt_mtp_block_spec - def _compat_get_gpt_mtp_block_spec( - config, spec, use_transformer_engine=True, **kwargs - ): - if isinstance(spec, TransformerConfig): + def _compat_fn(config, spec, use_transformer_engine=True, **kwargs): + spec = _convert_spec_if_needed(config, spec, use_transformer_engine) + return _orig_fn(config, spec, use_transformer_engine, **kwargs) + + _specs_mod.get_gpt_mtp_block_spec = _compat_fn + + # mbridge.models.mimo local reference (if available) --- + try: + import mbridge.models.mimo as _mimo_mod + + if hasattr(_mimo_mod, "get_gpt_mtp_block_spec"): + _mimo_mod.get_gpt_mtp_block_spec = _compat_fn logger.info( - "[MTPCompat] Auto-converting TransformerConfig -> ModuleSpec " - "for get_gpt_mtp_block_spec (use_transformer_engine=%s).", - use_transformer_engine, + "[MTPCompat] Also patched mbridge.models.mimo." + "get_gpt_mtp_block_spec direct reference." ) - _get_decoder = getattr(_specs_mod, "get_gpt_decoder_block_spec", None) - if _get_decoder is not None: - decoder_block_spec = _get_decoder( - config=config, use_transformer_engine=use_transformer_engine - ) - spec = decoder_block_spec.layer_specs[-1] - logger.info( - "[MTPCompat] Resolved spec via get_gpt_decoder_block_spec." - ) - elif use_transformer_engine: - spec = _specs_mod.get_gpt_layer_with_transformer_engine_spec() - logger.info( - "[MTPCompat] Resolved spec via get_gpt_layer_with_transformer_engine_spec." - ) - else: - spec = _specs_mod.get_gpt_layer_local_spec() - logger.info("[MTPCompat] Resolved spec via get_gpt_layer_local_spec.") - return _orig_fn(config, spec, use_transformer_engine, **kwargs) + except (ImportError, AttributeError): + logger.info( + "[MTPCompat] mbridge.models.mimo not importable; " + "relying on backend-level patch only." + ) - _specs_mod.get_gpt_mtp_block_spec = _compat_get_gpt_mtp_block_spec _specs_mod._areal_mtp_compat_patched = True - logger.info("[MTPCompat] Patched get_gpt_mtp_block_spec for TransformerConfig compat.") + logger.info( + "[MTPCompat] Patched get_gpt_mtp_block_spec AND " + "get_gpt_mtp_block_spec_for_backend for TransformerConfig compat." + ) # Model registry for different architectures From fffe03e1bc767e1d8a8d9c0ab6c38e5a187ade02 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 13 Apr 2026 23:12:59 +0800 Subject: [PATCH 006/140] feat: fix --- areal/models/mcore/registry.py | 37 +++++++++++++++++----------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/areal/models/mcore/registry.py b/areal/models/mcore/registry.py index a614f78f48..2240c03c7b 100644 --- a/areal/models/mcore/registry.py +++ b/areal/models/mcore/registry.py @@ -140,53 +140,54 @@ def _ensure_mtp_spec_compat(): return # ----- helper: convert TransformerConfig → proper ModuleSpec ----- - def _convert_spec_if_needed(config, spec, use_transformer_engine=True): + def _convert_spec_if_needed(config, spec): + """Convert TransformerConfig to a proper ModuleSpec for MTP block spec. + + Uses get_gpt_decoder_block_spec (preferred) or falls back to + get_gpt_layer_with_transformer_engine_spec / get_gpt_layer_local_spec. + Always uses transformer_engine=True since MTP with TE is the common case. + """ if not isinstance(spec, TransformerConfig): return spec logger.info( "[MTPCompat] Auto-converting TransformerConfig -> ModuleSpec " - "for get_gpt_mtp_block_spec (use_transformer_engine=%s).", - use_transformer_engine, + "for get_gpt_mtp_block_spec." ) _get_decoder = getattr(_specs_mod, "get_gpt_decoder_block_spec", None) if _get_decoder is not None: decoder_block_spec = _get_decoder( - config=config, use_transformer_engine=use_transformer_engine + config=config, use_transformer_engine=True ) spec = decoder_block_spec.layer_specs[-1] logger.info( "[MTPCompat] Resolved spec via get_gpt_decoder_block_spec." ) - elif use_transformer_engine: + else: spec = _specs_mod.get_gpt_layer_with_transformer_engine_spec() logger.info( "[MTPCompat] Resolved spec via " "get_gpt_layer_with_transformer_engine_spec." ) - else: - spec = _specs_mod.get_gpt_layer_local_spec() - logger.info("[MTPCompat] Resolved spec via get_gpt_layer_local_spec.") return spec # get_gpt_mtp_block_spec_for_backend --- - # This is the lowest-level function that validates the spec type. - # Because the original get_gpt_mtp_block_spec resolves this name - # through the module's global dict at call time, patching here - # intercepts ALL callers — including mimo's from-imported reference. + # Signature: (config, spec, backend, vp_stage=None, pp_rank=None) + # The 3rd param is `backend` (BackendSpecProvider), NOT use_transformer_engine. + # We only intercept config + spec; all other args pass through unchanged. _orig_backend_fn = _specs_mod.get_gpt_mtp_block_spec_for_backend - def _compat_backend(config, spec, use_transformer_engine=True, **kwargs): - spec = _convert_spec_if_needed(config, spec, use_transformer_engine) - return _orig_backend_fn(config, spec, use_transformer_engine, **kwargs) + def _compat_backend(config, spec, *args, **kwargs): + spec = _convert_spec_if_needed(config, spec) + return _orig_backend_fn(config, spec, *args, **kwargs) _specs_mod.get_gpt_mtp_block_spec_for_backend = _compat_backend # --- Patch 2: get_gpt_mtp_block_spec (top-level entry point) --- _orig_fn = _specs_mod.get_gpt_mtp_block_spec - def _compat_fn(config, spec, use_transformer_engine=True, **kwargs): - spec = _convert_spec_if_needed(config, spec, use_transformer_engine) - return _orig_fn(config, spec, use_transformer_engine, **kwargs) + def _compat_fn(config, spec, *args, **kwargs): + spec = _convert_spec_if_needed(config, spec) + return _orig_fn(config, spec, *args, **kwargs) _specs_mod.get_gpt_mtp_block_spec = _compat_fn From 76999a3c280f5e93c43dd5add336cb3dc1753663 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 13 Apr 2026 23:42:47 +0800 Subject: [PATCH 007/140] feat: fix --- areal/workflow/rlvr.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index db5e59ed7a..d4e5712ca5 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -152,14 +152,11 @@ async def _collect_samples( f"2) Checking MTP layer training status; " f"3) Reducing speculative_num_steps." ) - elif accept_rate < 0.5: - logger.info( - f"[SpecDec] Accept rate: {accept_rate:.4f} " - f"(accept={resp.spec_accept_token_num}, draft={resp.spec_draft_token_num})" - ) else: - logger.info( - f"[SpecDec] Good accept rate: {accept_rate:.4f} " + # Per-sample accept rates are recorded via stats_tracker above; + # log at DEBUG to avoid flooding stdout during large rollouts. + logger.debug( + f"[SpecDec] Accept rate: {accept_rate:.4f} " f"(accept={resp.spec_accept_token_num}, draft={resp.spec_draft_token_num})" ) From 468216864e7c47f69e469bdfeeffd192fffabc3c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 00:30:35 +0800 Subject: [PATCH 008/140] feat: fix --- areal/engine/megatron_engine.py | 59 +++++++------------ .../megatron_utils/packed_context_parallel.py | 38 ++++++++---- 2 files changed, 45 insertions(+), 52 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 59d3ab9db8..d056564d4f 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -865,51 +865,32 @@ def forward_step(batch_iter, model): mb_input.padded_mb.update(tree_kwargs) tree_attn_keys = list(tree_kwargs.keys()) - # Build MTP kwargs if MTP training is enabled + # Build MTP kwargs if MTP training is enabled AND this is a + # training pass (not forward-only inference like compute_logp). + # + # Megatron-core 0.16.x GPTModel.forward() computes MTP loss + # internally in _postprocess() when `labels` and `loss_mask` + # are provided. These must be passed as top-level kwargs to + # GPTModel.forward(), NOT inside extra_block_kwargs — because + # GPTModel unpacks extra_block_kwargs via **-splat into + # TransformerBlock.forward() which does not accept them. + # + # We pass labels/loss_mask through extra_block_kwargs as a + # transport dict; packed_context_parallel_forward extracts + # them before the actual model() call. extra_block_kwargs = None - if self.enable_mtp_training: + if self.enable_mtp_training and not forward_only: mtp_labels = mb_input.padded_mb["input_ids"] + mtp_loss_mask = mb_input.padded_mb.get("loss_mask", None) - loss_mask = mb_input.padded_mb.get("loss_mask", None) - mtp_loss_mask = None - if loss_mask is not None: - cu_seqlens = mb_input.padded_mb.get("cu_seqlens", None) - if cu_seqlens is not None: - mask_1 = self._roll_tensor_packed( - loss_mask, shift=-1, cu_seqlens=cu_seqlens - ) - mask_2 = self._roll_tensor_packed( - mask_1, shift=-1, cu_seqlens=cu_seqlens - ) - else: - mask_1 = torch.roll(loss_mask, shifts=-1, dims=-1) - mask_1[..., -1] = 0 - mask_2 = torch.roll(mask_1, shifts=-1, dims=-1) - mask_2[..., -1] = 0 - mtp_loss_mask = mask_1 * mask_2 - valid_mtp_tokens = mtp_loss_mask.sum().item() - total_mtp_tokens = mtp_loss_mask.numel() - self.logger.info( - f"[MTPTrain] MTP loss mask: valid_tokens={valid_mtp_tokens}, " - f"total_tokens={total_mtp_tokens}, " - f"mask_ratio={valid_mtp_tokens / max(total_mtp_tokens, 1):.4f}" - ) - else: - self.logger.warning( - "[MTPTrain] loss_mask is None; MTP loss will be computed over " - "all positions including padding. This may lead to incorrect " - "MTP loss values. Ensure loss_mask is provided in the input." - ) - - mtp_kwargs = {"mtp_labels": mtp_labels} + extra_block_kwargs = {"labels": mtp_labels} if mtp_loss_mask is not None: - mtp_kwargs["mtp_loss_mask"] = mtp_loss_mask - extra_block_kwargs = {"mtp_kwargs": mtp_kwargs} + extra_block_kwargs["loss_mask"] = mtp_loss_mask - self.logger.info( - f"[MTPTrain] Forward step: mtp_labels shape={mtp_labels.shape}, " + self.logger.debug( + f"[MTPTrain] Forward step: labels shape={mtp_labels.shape}, " f"dtype={mtp_labels.dtype}, " - f"has_mtp_loss_mask={mtp_loss_mask is not None}, " + f"has_loss_mask={mtp_loss_mask is not None}, " f"mtp_num_layers={self.mtp_num_layers}" ) output = packed_context_parallel_forward( diff --git a/areal/engine/megatron_utils/packed_context_parallel.py b/areal/engine/megatron_utils/packed_context_parallel.py index 0653f43cb5..6934e281dd 100644 --- a/areal/engine/megatron_utils/packed_context_parallel.py +++ b/areal/engine/megatron_utils/packed_context_parallel.py @@ -145,19 +145,17 @@ def packed_context_parallel_forward( ) input_ids = input_ids.contiguous() - # Also split MTP labels with the same CP logic if present - if extra_block_kwargs and "mtp_kwargs" in extra_block_kwargs: - mtp_kwargs = extra_block_kwargs["mtp_kwargs"] - if "mtp_labels" in mtp_kwargs: - mtp_labels_split, _ = preprocess_packed_seqs_context_parallel( - mtp_kwargs["mtp_labels"], cu_seqlens - ) - mtp_kwargs["mtp_labels"] = mtp_labels_split.contiguous() - if "mtp_loss_mask" in mtp_kwargs: - mtp_mask_split, _ = preprocess_packed_seqs_context_parallel( - mtp_kwargs["mtp_loss_mask"], cu_seqlens - ) - mtp_kwargs["mtp_loss_mask"] = mtp_mask_split.contiguous() + # Split MTP labels / loss_mask with the same CP logic if present. + # These tensors are passed via extra_block_kwargs and will be + # forwarded to GPTModel.forward() as `labels` and `loss_mask` + # so that megatron-core computes MTP loss internally. + if extra_block_kwargs: + for key in ("labels", "loss_mask"): + if key in extra_block_kwargs: + split_val, _ = preprocess_packed_seqs_context_parallel( + extra_block_kwargs[key], cu_seqlens + ) + extra_block_kwargs[key] = split_val.contiguous() # Pass tree_triton_data as attention_mask if present (for Triton tree attention) # Otherwise use the attention_mask from input (could be dense tensor for flex attention) @@ -165,6 +163,19 @@ def packed_context_parallel_forward( tree_triton_data if tree_triton_data is not None else attention_mask ) + # Extract model-level forward kwargs (labels, loss_mask) from + # extra_block_kwargs. These must be passed as top-level keyword + # arguments to GPTModel.forward() — NOT inside extra_block_kwargs, + # because GPTModel unpacks extra_block_kwargs via **-splat into + # TransformerBlock.forward() which does not accept them. + model_fwd_kwargs: dict[str, Any] = {} + if extra_block_kwargs: + for key in ("labels", "loss_mask"): + if key in extra_block_kwargs: + model_fwd_kwargs[key] = extra_block_kwargs.pop(key) + if not extra_block_kwargs: + extra_block_kwargs = None + try: output = model( input_ids=input_ids, @@ -172,6 +183,7 @@ def packed_context_parallel_forward( position_ids=position_ids, packed_seq_params=packed_seq_params, extra_block_kwargs=extra_block_kwargs, + **model_fwd_kwargs, ) except Exception as e: raise RuntimeError( From 8a9e8d104daf736a0654684ee8d1532b6d0fd209 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 10:46:39 +0800 Subject: [PATCH 009/140] feat: fix code --- areal/engine/megatron_engine.py | 73 ++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index d056564d4f..5ee83e2b79 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -865,37 +865,54 @@ def forward_step(batch_iter, model): mb_input.padded_mb.update(tree_kwargs) tree_attn_keys = list(tree_kwargs.keys()) - # Build MTP kwargs if MTP training is enabled AND this is a - # training pass (not forward-only inference like compute_logp). + # ---- MTP safety: disable in GPTModel._postprocess() ---- + # megatron-core 0.16.x GPTModel._postprocess() has a known + # issue: the MTP loss block calls `labels.clone()` BEFORE + # the `if labels is None: return logits` early-return check. + # This causes AttributeError during inference (labels=None). + # Additionally, when labels ARE provided, _postprocess() + # returns loss (not logits), which is incompatible with + # AReaL's custom loss computation pipeline. # - # Megatron-core 0.16.x GPTModel.forward() computes MTP loss - # internally in _postprocess() when `labels` and `loss_mask` - # are provided. These must be passed as top-level kwargs to - # GPTModel.forward(), NOT inside extra_block_kwargs — because - # GPTModel unpacks extra_block_kwargs via **-splat into - # TransformerBlock.forward() which does not accept them. - # - # We pass labels/loss_mask through extra_block_kwargs as a - # transport dict; packed_context_parallel_forward extracts - # them before the actual model() call. + # Workaround: temporarily disable MTP on the unwrapped + # GPTModel so that _postprocess() skips MTP entirely and + # always returns logits. MTP parameters are still part of + # the model and will be saved/loaded with checkpoints. extra_block_kwargs = None - if self.enable_mtp_training and not forward_only: - mtp_labels = mb_input.padded_mb["input_ids"] - mtp_loss_mask = mb_input.padded_mb.get("loss_mask", None) - - extra_block_kwargs = {"labels": mtp_labels} - if mtp_loss_mask is not None: - extra_block_kwargs["loss_mask"] = mtp_loss_mask - - self.logger.debug( - f"[MTPTrain] Forward step: labels shape={mtp_labels.shape}, " - f"dtype={mtp_labels.dtype}, " - f"has_loss_mask={mtp_loss_mask is not None}, " - f"mtp_num_layers={self.mtp_num_layers}" + _mtp_restore = None + if self.enable_mtp_training: + _unwrapped = model + while hasattr(_unwrapped, 'module'): + _unwrapped = _unwrapped.module + _saved_mtp = getattr(_unwrapped, 'mtp', None) + _saved_mtp_layers = getattr( + _unwrapped.config, 'mtp_num_layers', None ) - output = packed_context_parallel_forward( - model, mb_input.padded_mb, extra_block_kwargs=extra_block_kwargs - ) + if _saved_mtp is not None or _saved_mtp_layers is not None: + _unwrapped.mtp = None + _unwrapped.config.mtp_num_layers = None + _mtp_restore = ( + _unwrapped, _saved_mtp, _saved_mtp_layers, + ) + self.logger.debug( + f"[MTPTrain] Temporarily disabled MTP in " + f"_postprocess (forward_only={forward_only}, " + f"saved_mtp_num_layers={_saved_mtp_layers})" + ) + + try: + output = packed_context_parallel_forward( + model, mb_input.padded_mb, + extra_block_kwargs=extra_block_kwargs, + ) + finally: + if _mtp_restore is not None: + _uw, _sm, _sl = _mtp_restore + _uw.mtp = _sm + _uw.config.mtp_num_layers = _sl + self.logger.debug( + "[MTPTrain] Restored MTP after forward pass" + ) # Release tree attention metadata after forward pass for key in tree_attn_keys: From e9557082fc751b07c1959457d9401bad532ef3fe Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 11:19:01 +0800 Subject: [PATCH 010/140] feat: add log --- areal/workflow/rlvr.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index d4e5712ca5..d00be231fc 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -153,9 +153,7 @@ async def _collect_samples( f"3) Reducing speculative_num_steps." ) else: - # Per-sample accept rates are recorded via stats_tracker above; - # log at DEBUG to avoid flooding stdout during large rollouts. - logger.debug( + logger.info( f"[SpecDec] Accept rate: {accept_rate:.4f} " f"(accept={resp.spec_accept_token_num}, draft={resp.spec_draft_token_num})" ) From 985cb26a930e3c893a9f5bfe54c044b1126dd459 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 11:29:53 +0800 Subject: [PATCH 011/140] fet: fix --- areal/engine/megatron_engine.py | 46 +++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 5ee83e2b79..6257a1cbe6 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -866,18 +866,23 @@ def forward_step(batch_iter, model): tree_attn_keys = list(tree_kwargs.keys()) # ---- MTP safety: disable in GPTModel._postprocess() ---- - # megatron-core 0.16.x GPTModel._postprocess() has a known - # issue: the MTP loss block calls `labels.clone()` BEFORE - # the `if labels is None: return logits` early-return check. - # This causes AttributeError during inference (labels=None). - # Additionally, when labels ARE provided, _postprocess() - # returns loss (not logits), which is incompatible with - # AReaL's custom loss computation pipeline. + # megatron-core 0.16.x GPTModel._postprocess() has known + # issues that crash both inference and training paths: # - # Workaround: temporarily disable MTP on the unwrapped - # GPTModel so that _postprocess() skips MTP entirely and - # always returns logits. MTP parameters are still part of - # the model and will be saved/loaded with checkpoints. + # 1) `if mtp_in_postprocess:` calls `self.mtp(...)`. + # `mtp_in_postprocess` comes from `self.mtp_process`, + # a bool cached in __init__ from `mtp_block_spec is + # not None`. Setting self.mtp=None is NOT enough — + # self.mtp_process must also be set to False. + # 2) `if self.config.mtp_num_layers is not None:` calls + # `labels.clone()` BEFORE the `labels is None` early + # return, crashing during inference. + # 3) When labels IS provided, _postprocess() returns loss + # (not logits), incompatible with AReaL's loss pipeline. + # + # Workaround: temporarily disable all three MTP flags on + # the unwrapped GPTModel so _postprocess() skips MTP + # entirely and always returns logits. extra_block_kwargs = None _mtp_restore = None if self.enable_mtp_training: @@ -885,18 +890,30 @@ def forward_step(batch_iter, model): while hasattr(_unwrapped, 'module'): _unwrapped = _unwrapped.module _saved_mtp = getattr(_unwrapped, 'mtp', None) + _saved_mtp_process = getattr( + _unwrapped, 'mtp_process', None + ) _saved_mtp_layers = getattr( _unwrapped.config, 'mtp_num_layers', None ) - if _saved_mtp is not None or _saved_mtp_layers is not None: + if ( + _saved_mtp is not None + or _saved_mtp_process + or _saved_mtp_layers is not None + ): _unwrapped.mtp = None + _unwrapped.mtp_process = False _unwrapped.config.mtp_num_layers = None _mtp_restore = ( - _unwrapped, _saved_mtp, _saved_mtp_layers, + _unwrapped, + _saved_mtp, + _saved_mtp_process, + _saved_mtp_layers, ) self.logger.debug( f"[MTPTrain] Temporarily disabled MTP in " f"_postprocess (forward_only={forward_only}, " + f"saved_mtp_process={_saved_mtp_process}, " f"saved_mtp_num_layers={_saved_mtp_layers})" ) @@ -907,8 +924,9 @@ def forward_step(batch_iter, model): ) finally: if _mtp_restore is not None: - _uw, _sm, _sl = _mtp_restore + _uw, _sm, _sp, _sl = _mtp_restore _uw.mtp = _sm + _uw.mtp_process = _sp _uw.config.mtp_num_layers = _sl self.logger.debug( "[MTPTrain] Restored MTP after forward pass" From 6670f3dc38e935983abd7f29cd26e0454b921997 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 14:00:53 +0800 Subject: [PATCH 012/140] feat: fix --- areal/engine/megatron_utils/megatron.py | 85 +++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/areal/engine/megatron_utils/megatron.py b/areal/engine/megatron_utils/megatron.py index 55d846fb34..5d08d41f24 100644 --- a/areal/engine/megatron_utils/megatron.py +++ b/areal/engine/megatron_utils/megatron.py @@ -815,6 +815,90 @@ def convert_bailingmoe_to_hf( raise ValueError(f"Unknown parameter name: {name}") +def convert_mimo_to_hf( + tf_config: TransformerConfig, + name: str, + param: Parameter | Tensor | FP8BlockwiseTensorHelper, +): + """Convert MiMo model parameters from Megatron to HuggingFace format. + + MiMo extends Qwen2 with MTP (Multi-Token Prediction) layers. + Non-MTP parameters are delegated to the Qwen2 converter. + """ + if "mtp" in name: + return _convert_mimo_mtp_param(tf_config, name, param) + + return convert_qwen2_to_hf(tf_config, name, param) + + +def _convert_mimo_mtp_param( + tf_config: TransformerConfig, + name: str, + param: Parameter | Tensor | FP8BlockwiseTensorHelper, +): + """Convert MiMo MTP layer parameters from Megatron to HuggingFace format. + + MTP layers in MiMo contain: + - LayerNorms (enorm/token_layernorm, hnorm/hidden_layernorm, final_layernorm) + - Input projection (eh_proj/input_proj) with column-half swap + - Self attention (reuses Qwen2 attention structure via transformer_layer) + - MLP (reuses Qwen2 MLP structure via transformer_layer) + + Handles two naming patterns produced by different megatron-core versions: + - module.module.mtp.layers.{idx}.{component} (mcore native) + - module.module.decoder.mtp_layers.{idx}.{component} + """ + mtp_pattern1 = r"module\.module\.mtp\.layers\.(\d+)\.(.+)" + mtp_pattern2 = r"module\.module\.decoder\.mtp_layers\.(\d+)\.(.+)" + + match = re.match(mtp_pattern1, name) + if match is None: + match = re.match(mtp_pattern2, name) + + if match is None: + raise ValueError(f"Invalid MiMo MTP parameter name: {name}") + + layer_idx, component = match.groups() + + # Direct mappings for MTP-specific components (Megatron -> HF) + direct_mappings = { + "enorm.weight": f"model.mtp_layers.{layer_idx}.token_layernorm.weight", + "hnorm.weight": f"model.mtp_layers.{layer_idx}.hidden_layernorm.weight", + "eh_proj.weight": f"model.mtp_layers.{layer_idx}.input_proj.weight", + "final_layernorm.weight": f"model.mtp_layers.{layer_idx}.final_layernorm.weight", + } + + # MiMo-specific: swap column halves for eh_proj weight + if component == "eh_proj.weight": + first_half, second_half = param.chunk(2, dim=1) + param = torch.cat([second_half, first_half], dim=1) + + # Check direct mappings first + if component in direct_mappings: + return [(direct_mappings[component], param)] + + # Handle transformer_layer components by delegating to Qwen2 converter + if component.startswith("transformer_layer."): + transformer_component = component[len("transformer_layer."):] + + # Create proxy name for reusing existing Qwen2 conversion + proxy_name = f"module.module.decoder.layers.{layer_idx}.{transformer_component}" + + # Use existing convert_qwen2_to_hf for transformer components + results = convert_qwen2_to_hf(tf_config, proxy_name, param) + + # Replace model.layers with model.mtp_layers in results + converted_results = [] + for hf_name, hf_param in results: + hf_name = hf_name.replace( + f"model.layers.{layer_idx}", f"model.mtp_layers.{layer_idx}" + ) + converted_results.append((hf_name, hf_param)) + + return converted_results + + raise ValueError(f"Unknown MiMo MTP component: {component} in {name}") + # Adapted from slime # A registry for conversion functions is more extensible. @@ -828,6 +912,7 @@ def convert_bailingmoe_to_hf( "bailing_moe_v2": convert_bailingmoe_to_hf, "bailing_moe_linear": convert_bailingmoe_to_hf, "bailing_hybrid": convert_bailingmoe_to_hf, + "mimo": convert_mimo_to_hf, } From 6fdcbc7757afed2f9a43175f8c32a7f82755c573 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 14:48:47 +0800 Subject: [PATCH 013/140] feat: fix code --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 73a9d90778..1f126a21a1 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -40,7 +40,7 @@ gconfig: temperature: 1.0 actor: - backend: "megatron:d1p1t4" + backend: "megatron:d1p1t6" experiment_name: ${experiment_name} trial_name: ${trial_name} path: /workspace/models/MiMo-7B-RL @@ -119,7 +119,7 @@ sglang: dtype: ${actor.dtype} max_running_requests: null context_length: 32768 - mem_fraction_static: 0.5 + mem_fraction_static: 0.4 # EAGLE speculative decoding settings speculative_algorithm: "EAGLE" From 261c9c7206d434f4dda165cff4528fed3261650d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 18:18:17 +0800 Subject: [PATCH 014/140] feat: fix mtp --- areal/engine/megatron_engine.py | 41 ++++++++-- areal/models/mcore/registry.py | 131 ++++++++++++++++++++++++++++---- 2 files changed, 151 insertions(+), 21 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 6257a1cbe6..36fdfd76bf 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -449,23 +449,48 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): if self.enable_mtp_training and not self._mtp_layers_verified: mtp_param_count = 0 + mtp_param_names = [] for module in modules: for name, param in module.named_parameters(): if ".mtp." in name: mtp_param_count += param.numel() + if len(mtp_param_names) < 5: + mtp_param_names.append(name) + + # With pipeline parallelism, MTP layers only exist on the last stage. + # Non-last stages legitimately have 0 MTP params. + is_last_stage = True + try: + import megatron.core.parallel_state as mpu + + if mpu.is_initialized() and mpu.get_pipeline_model_parallel_world_size() > 1: + is_last_stage = mpu.is_pipeline_last_stage() + except Exception: + pass + if mtp_param_count == 0: - self.logger.error( - "[MTPTrain] enable_mtp_training=True but NO MTP parameters found in model! " - "Possible causes: 1) mtp_num_layers=0 in model config; " - "2) Model checkpoint does not contain MTP layers; " - "3) mcore_config.mtp_num_layers not set correctly. " - "MTP loss will NOT be computed." - ) + if not is_last_stage: + self._mtp_layers_verified = True + self.logger.info( + "[MTPTrain] This rank is NOT on the last pipeline stage; " + "MTP parameters are expected only on the last stage. " + "Skipping MTP param verification on this rank." + ) + else: + self.logger.error( + "[MTPTrain] enable_mtp_training=True but NO MTP parameters found " + "on the LAST pipeline stage! " + "Possible causes: 1) mtp_num_layers=0 in model config; " + "2) Model checkpoint does not contain MTP layers; " + "3) mbridge did not pass mtp_block_spec to GPTModel. " + "MTP loss will NOT be computed." + ) else: self._mtp_layers_verified = True self.logger.info( f"[MTPTrain] Verified MTP parameters in model: " - f"total_mtp_params={mtp_param_count / 1e6:.2f}M" + f"total_mtp_params={mtp_param_count / 1e6:.2f}M, " + f"sample_params={mtp_param_names}" ) self._initialized = True diff --git a/areal/models/mcore/registry.py b/areal/models/mcore/registry.py index 2240c03c7b..2825daef0f 100644 --- a/areal/models/mcore/registry.py +++ b/areal/models/mcore/registry.py @@ -268,6 +268,105 @@ def make_mcore_layer_specs(hf_config: PretrainedConfig, tf_config: TransformerCo ) +def _ensure_mtp_in_gpt_model(log): + """Monkey-patch GPTModel.__init__ to auto-inject mtp_block_spec when mbridge + doesn't pass it but config.mtp_num_layers > 0. + + mbridge calls bridge.get_model() which internally creates GPTModel, but does NOT + pass mtp_block_spec even when config.mtp_num_layers > 0. GPTModel checks + ``mtp_block_spec is not None`` (the constructor argument) to decide whether to + create MTP layers -- it does NOT check config.mtp_num_layers. + + This patch intercepts GPTModel.__init__ and, when mtp_block_spec is missing but + config indicates MTP should be used, resolves the spec and injects it. + + Returns a callable that restores the original __init__. + """ + from megatron.core.models.gpt import GPTModel + + _original_init = GPTModel.__init__ + + def _patched_init(self, *args, **kwargs): + config = kwargs.get("config", args[0] if args else None) + mtp_block_spec = kwargs.get("mtp_block_spec", None) + + if ( + mtp_block_spec is None + and config is not None + and getattr(config, "mtp_num_layers", 0) > 0 + ): + log.info( + "[MTPTrain] GPTModel.__init__ intercepted: mtp_block_spec is None " + f"but config.mtp_num_layers={config.mtp_num_layers}. " + "Auto-resolving mtp_block_spec..." + ) + try: + tls = kwargs.get("transformer_layer_spec", None) + if tls is None and len(args) > 1: + tls = args[1] + + spec_for_mtp = None + if tls is not None: + from megatron.core.transformer.spec_utils import ModuleSpec + from megatron.core.transformer.transformer_block import ( + TransformerBlockSubmodules, + ) + + if isinstance(tls, ModuleSpec): + submodules = getattr(tls, "submodules", None) + if isinstance(submodules, TransformerBlockSubmodules): + layers = getattr(submodules, "layer_specs", None) + if layers and len(layers) > 0: + spec_for_mtp = layers[-1] + log.info( + f"[MTPTrain] Extracted layer spec from " + f"TransformerBlockSubmodules (n_layers={len(layers)})" + ) + elif isinstance(tls, TransformerBlockSubmodules): + layers = getattr(tls, "layer_specs", None) + if layers and len(layers) > 0: + spec_for_mtp = layers[-1] + + if spec_for_mtp is None: + log.warning( + "[MTPTrain] Could not extract layer spec from " + "transformer_layer_spec; falling back to " + "get_gpt_mtp_block_spec(config)." + ) + from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_mtp_block_spec, + ) + resolved_spec = get_gpt_mtp_block_spec(config) + else: + from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_mtp_block_spec, + ) + resolved_spec = get_gpt_mtp_block_spec(config, spec_for_mtp) + + kwargs["mtp_block_spec"] = resolved_spec + log.info( + f"[MTPTrain] Injected mtp_block_spec into GPTModel.__init__: " + f"type={type(resolved_spec).__name__}" + ) + except Exception as e: + log.error( + f"[MTPTrain] Failed to auto-resolve mtp_block_spec: {e}. " + "MTP layers will NOT be created.", + exc_info=True, + ) + + return _original_init(self, *args, **kwargs) + + GPTModel.__init__ = _patched_init + log.info("[MTPTrain] GPTModel.__init__ monkey-patched to auto-inject mtp_block_spec.") + + def _restore(): + GPTModel.__init__ = _original_init + log.info("[MTPTrain] GPTModel.__init__ restored to original.") + + return _restore + + def make_mcore_model( hf_config: PretrainedConfig, tf_config: TransformerConfig, @@ -282,24 +381,30 @@ def make_mcore_model( # Patch get_gpt_mtp_block_spec before mbridge calls it so that a # TransformerConfig passed as ``spec`` is auto-converted to the # correct ModuleSpec type expected by megatron-core. + _restore_mtp_inject = None if enable_mtp: _ensure_mtp_spec_compat() logger.info( "[MTPTrain] Applied MTP spec compatibility patch before mbridge model creation." ) - - models = bridge.get_model( - # TODO: Add DDP options when supporting training - wrap_with_ddp=mcore_config.wrap_with_ddp, - ddp_config=dataclasses.asdict(mcore_config.ddp), - use_torch_fsdp2=mcore_config.use_torch_fsdp2, - use_custom_fsdp=mcore_config.use_custom_fsdp, - fp16=tf_config.fp16, - bf16=tf_config.bf16, - use_precision_aware_optimizer=mcore_config.use_precision_aware_optimizer, - overlap_param_gather_with_optimizer_step=mcore_config.overlap_param_gather_with_optimizer_step, - ) - models = list(models) + _restore_mtp_inject = _ensure_mtp_in_gpt_model(logger) + + try: + models = bridge.get_model( + # TODO: Add DDP options when supporting training + wrap_with_ddp=mcore_config.wrap_with_ddp, + ddp_config=dataclasses.asdict(mcore_config.ddp), + use_torch_fsdp2=mcore_config.use_torch_fsdp2, + use_custom_fsdp=mcore_config.use_custom_fsdp, + fp16=tf_config.fp16, + bf16=tf_config.bf16, + use_precision_aware_optimizer=mcore_config.use_precision_aware_optimizer, + overlap_param_gather_with_optimizer_step=mcore_config.overlap_param_gather_with_optimizer_step, + ) + models = list(models) + finally: + if _restore_mtp_inject is not None: + _restore_mtp_inject() # Replace output_layer with ValueHead for critic models if is_critic: From bd36e6a2f29827cc147b9f7c36b724a84daa39a9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 18:34:54 +0800 Subject: [PATCH 015/140] feat: remove --- areal/engine/megatron_engine.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 36fdfd76bf..e82c013004 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -461,8 +461,6 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): # Non-last stages legitimately have 0 MTP params. is_last_stage = True try: - import megatron.core.parallel_state as mpu - if mpu.is_initialized() and mpu.get_pipeline_model_parallel_world_size() > 1: is_last_stage = mpu.is_pipeline_last_stage() except Exception: From 4924ee42e2822ddcf0bfc33390082a1789febd3e Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 18:37:32 +0800 Subject: [PATCH 016/140] feat: change config --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 1f126a21a1..eb22e62240 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -19,7 +19,7 @@ scheduler: type: null rollout: - backend: "sglang:d1p1t2" + backend: "sglang:d1p1t8" experiment_name: ${experiment_name} trial_name: ${trial_name} max_concurrent_rollouts: 128 @@ -40,7 +40,7 @@ gconfig: temperature: 1.0 actor: - backend: "megatron:d1p1t6" + backend: "megatron:d1p1t8" experiment_name: ${experiment_name} trial_name: ${trial_name} path: /workspace/models/MiMo-7B-RL From 9f7796bc6dbcf6015f9023dd8fb32ef540e8e2cb Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 19:04:51 +0800 Subject: [PATCH 017/140] feat: improve mtp_loss_scaling_factor --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index eb22e62240..8b86e44a82 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -81,7 +81,7 @@ actor: # MTP Online Training - keeps EAGLE draft heads aligned with evolving policy enable_mtp_training: true mtp_num_layers: 1 - mtp_loss_scaling_factor: 0.1 + mtp_loss_scaling_factor: 0.2 scheduling_spec: - task_type: worker From e0471c366e471b854726e170c5e353b3c67396e0 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 19:39:30 +0800 Subject: [PATCH 018/140] feat: fix config --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 8b86e44a82..f0d422fd8c 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -8,16 +8,17 @@ tokenizer_path: ${actor.path} cluster: n_nodes: 1 - n_gpus_per_node: 6 + n_gpus_per_node: 8 fileroot: /tmp/areal/experiments name_resolve: type: nfs nfs_record_root: /tmp/areal/name_resolve - scheduler: type: null +allocation_mode: "sglang[rollout]:d1p1t8|megatron[actor]:d1p1t8" + rollout: backend: "sglang:d1p1t8" experiment_name: ${experiment_name} From 0bc4616bdbe14ab41723fab8ad92a34575d79c88 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 20:38:33 +0800 Subject: [PATCH 019/140] feat: fix --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index f0d422fd8c..478eeaaf76 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -17,10 +17,8 @@ cluster: scheduler: type: null -allocation_mode: "sglang[rollout]:d1p1t8|megatron[actor]:d1p1t8" - rollout: - backend: "sglang:d1p1t8" + backend: "sglang:d1p1t4" experiment_name: ${experiment_name} trial_name: ${trial_name} max_concurrent_rollouts: 128 @@ -41,7 +39,7 @@ gconfig: temperature: 1.0 actor: - backend: "megatron:d1p1t8" + backend: "megatron:d1p1t4" experiment_name: ${experiment_name} trial_name: ${trial_name} path: /workspace/models/MiMo-7B-RL @@ -120,7 +118,7 @@ sglang: dtype: ${actor.dtype} max_running_requests: null context_length: 32768 - mem_fraction_static: 0.4 + mem_fraction_static: 0.5 # EAGLE speculative decoding settings speculative_algorithm: "EAGLE" From 031fa897303ff606e40f2b650cdf8430212bf0e5 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 21:44:03 +0800 Subject: [PATCH 020/140] feat: fix local for test --- areal/infra/launcher/local.py | 16 ++++++++++++++++ examples/math/gsm8k_grpo_megatron_mimo.yaml | 6 ++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/areal/infra/launcher/local.py b/areal/infra/launcher/local.py index c92cba902e..1ceb9e6819 100644 --- a/areal/infra/launcher/local.py +++ b/areal/infra/launcher/local.py @@ -378,6 +378,7 @@ def local_main(config, run_id: int = 0): tms_env_vars = {} # Launch trainer entrypoint if alloc_mode.type_ != AllocationType.LLM_SERVER_ONLY: + is_colocate = alloc_mode.type_ == AllocationType.COLOCATE gpu = nprocs = alloc_mode.train.world_size _env_vars = dict( AREAL_LLM_SERVER_ADDRS=",".join(server_addrs), @@ -394,6 +395,21 @@ def local_main(config, run_id: int = 0): cpus_per_task=actor_cpus_per_task, existing_env_vars=actor_env_vars, ) + if is_colocate: + # In colocation mode, trainer reuses the same GPUs as the + # inference server. We roll back the GPU counter so that + # submit_array assigns the same CUDA_VISIBLE_DEVICES that was + # given to the llm_server job. + gen_gpu = ( + alloc_mode.gen.dp_size + * alloc_mode.gen.pp_size + * alloc_mode.gen.tp_size + ) + launcher._gpu_counter = max(0, launcher._gpu_counter - gen_gpu) + logger.info( + f"[Colocation] Trainer will share {gen_gpu} GPUs with the inference server. " + f"GPU counter rolled back to {launcher._gpu_counter}." + ) launcher.submit( job_name="trainer", cmd=f"torchrun --nnodes 1 --nproc-per-node {nprocs} --master-addr localhost --master-port {find_free_ports(1, (10000, 50000))[0]} {' '.join(sys.argv[1:])}", diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 478eeaaf76..94309cd8a3 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -17,8 +17,10 @@ cluster: scheduler: type: null +allocation_mode: "sglang[rollout]:d1p1t8|megatron[actor]:d1p1t8" + rollout: - backend: "sglang:d1p1t4" + backend: "sglang:d1p1t8" experiment_name: ${experiment_name} trial_name: ${trial_name} max_concurrent_rollouts: 128 @@ -39,7 +41,7 @@ gconfig: temperature: 1.0 actor: - backend: "megatron:d1p1t4" + backend: "megatron:d1p1t8" experiment_name: ${experiment_name} trial_name: ${trial_name} path: /workspace/models/MiMo-7B-RL From b36ca69c247c488e683c3cb0dac13a40376205f3 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 22:44:09 +0800 Subject: [PATCH 021/140] feat: bug fix --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 94309cd8a3..097bb20e92 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -8,17 +8,16 @@ tokenizer_path: ${actor.path} cluster: n_nodes: 1 - n_gpus_per_node: 8 + n_gpus_per_node: 6 fileroot: /tmp/areal/experiments name_resolve: type: nfs nfs_record_root: /tmp/areal/name_resolve + scheduler: type: null -allocation_mode: "sglang[rollout]:d1p1t8|megatron[actor]:d1p1t8" - rollout: backend: "sglang:d1p1t8" experiment_name: ${experiment_name} @@ -36,7 +35,7 @@ rollout: gconfig: n_samples: 4 min_new_tokens: 0 - max_new_tokens: 1024 + max_new_tokens: 2048 greedy: false temperature: 1.0 @@ -120,7 +119,7 @@ sglang: dtype: ${actor.dtype} max_running_requests: null context_length: 32768 - mem_fraction_static: 0.5 + mem_fraction_static: 0.4 # EAGLE speculative decoding settings speculative_algorithm: "EAGLE" From a07dc22b30cdbd6b1616d58482ee34745a5b6502 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 14 Apr 2026 23:23:46 +0800 Subject: [PATCH 022/140] feat: fix --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 097bb20e92..05db882862 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -8,7 +8,7 @@ tokenizer_path: ${actor.path} cluster: n_nodes: 1 - n_gpus_per_node: 6 + n_gpus_per_node: 8 fileroot: /tmp/areal/experiments name_resolve: type: nfs @@ -19,10 +19,10 @@ scheduler: type: null rollout: - backend: "sglang:d1p1t8" + backend: "sglang:d2p1t2" experiment_name: ${experiment_name} trial_name: ${trial_name} - max_concurrent_rollouts: 128 + max_concurrent_rollouts: 256 queue_size: null consumer_batch_size: ${train_dataset.batch_size} max_head_offpolicyness: 2 @@ -40,13 +40,13 @@ gconfig: temperature: 1.0 actor: - backend: "megatron:d1p1t8" + backend: "megatron:d1p1t4" experiment_name: ${experiment_name} trial_name: ${trial_name} path: /workspace/models/MiMo-7B-RL init_from_scratch: false disable_dropout: true - gradient_checkpointing: false + gradient_checkpointing: true dtype: bfloat16 mb_spec: max_tokens_per_mb: 10240 From 3a56483ffb46efa80b24bde7a1e387a6f54fc5b9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 15 Apr 2026 01:15:43 +0800 Subject: [PATCH 023/140] feat: add base config --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- .../math/gsm8k_grpo_megatron_mimo_base.yaml | 180 ++++++++++++++++++ 2 files changed, 181 insertions(+), 1 deletion(-) create mode 100644 examples/math/gsm8k_grpo_megatron_mimo_base.yaml diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 05db882862..2b7fbf41e0 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -93,7 +93,7 @@ actor: megatron: mtp_num_layers: 1 - mtp_loss_scaling_factor: 0.1 + mtp_loss_scaling_factor: 0.2 ref: backend: ${actor.backend} diff --git a/examples/math/gsm8k_grpo_megatron_mimo_base.yaml b/examples/math/gsm8k_grpo_megatron_mimo_base.yaml new file mode 100644 index 0000000000..286721c815 --- /dev/null +++ b/examples/math/gsm8k_grpo_megatron_mimo_base.yaml @@ -0,0 +1,180 @@ +experiment_name: gsm8k-grpo-megatron +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + + +scheduler: + type: null + +rollout: + backend: "sglang:d2p1t2" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 2048 + greedy: false + temperature: 1.0 + +actor: + backend: "megatron:d1p1t4" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: /workspace/models/MiMo-7B-RL + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 3e-6 + weight_decay: 0.003 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behave_imp_weight_cap: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.4 + + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.9 + +# datasets +train_dataset: + batch_size: 128 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 128 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false From 05bdf24b632a2f0eac511b25ad585f7ee29be45e Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 15 Apr 2026 01:23:27 +0800 Subject: [PATCH 024/140] feat: remove mtp keep --- areal/engine/megatron_engine.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index e82c013004..9f4aa5cbfc 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -353,6 +353,14 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): f"mtp_loss_scaling_factor={self.mtp_loss_scaling_factor}, " f"mtp_detach_heads={self.mtp_detach_heads}" ) + else: + if getattr(self.tf_config, "mtp_num_layers", 0) > 0: + self.logger.info( + f"[MTPTrain] MTP training disabled but tf_config.mtp_num_layers=" + f"{self.tf_config.mtp_num_layers}. Resetting to 0 to prevent " + f"mbridge from creating MTP layers (avoids Invalid spec error)." + ) + self.tf_config.mtp_num_layers = 0 with self.device: models = make_mcore_model( From 41ac1ab522eb7bba5d47da7d191b5eb6b259a2bd Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 15 Apr 2026 10:45:33 +0800 Subject: [PATCH 025/140] feat: fix mtp loss --- areal/engine/megatron_engine.py | 142 ++++++++++++++------ examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 2 files changed, 100 insertions(+), 44 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 9f4aa5cbfc..ff0324e7a1 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -796,7 +796,7 @@ def _collect_mtp_loss(self) -> dict[str, float]: op=torch.distributed.ReduceOp.AVG, ) - mtp_loss_value = values.item() + mtp_loss_value = values.sum().item() self._mtp_loss_value = mtp_loss_value if is_last_pp_stage: @@ -896,56 +896,109 @@ def forward_step(batch_iter, model): mb_input.padded_mb.update(tree_kwargs) tree_attn_keys = list(tree_kwargs.keys()) - # ---- MTP safety: disable in GPTModel._postprocess() ---- - # megatron-core 0.16.x GPTModel._postprocess() has known - # issues that crash both inference and training paths: + # ---- MTP handling in GPTModel._postprocess() ---- # - # 1) `if mtp_in_postprocess:` calls `self.mtp(...)`. - # `mtp_in_postprocess` comes from `self.mtp_process`, - # a bool cached in __init__ from `mtp_block_spec is - # not None`. Setting self.mtp=None is NOT enough — - # self.mtp_process must also be set to False. - # 2) `if self.config.mtp_num_layers is not None:` calls - # `labels.clone()` BEFORE the `labels is None` early - # return, crashing during inference. - # 3) When labels IS provided, _postprocess() returns loss - # (not logits), incompatible with AReaL's loss pipeline. + # megatron-core 0.16.x _postprocess() behaviour: # - # Workaround: temporarily disable all three MTP flags on - # the unwrapped GPTModel so _postprocess() skips MTP - # entirely and always returns logits. + # if mtp_in_postprocess: + # hidden_states = self.mtp(...) # MTP forward + # if config.mtp_num_layers is not None: + # + # logits = self.output_layer(hidden_states) + # if labels is None: return logits + # return compute_language_model_loss(labels, logits) + # + # Inference (forward_only=True): + # AReaL does NOT pass labels, so labels.clone() crashes. + # MTP forward is also unnecessary for logprob collection. + # → Disable MTP entirely so _postprocess returns logits. + # + # Training (forward_only=False): + # We NEED MTP loss in the autograd graph for draft-model + # training, but AReaL also needs logits (not CE loss) for + # its RL loss pipeline. Strategy: + # 1. Keep MTP enabled so _postprocess runs mtp forward + # and computes MTP loss (MTPLossAutoScaler). + # 2. Pass labels & loss_mask via extra_block_kwargs. + # 3. Monkey-patch compute_language_model_loss: the LAST + # call (main CE) returns logits instead of loss; + # earlier calls (per-MTP-layer) use real CE. extra_block_kwargs = None _mtp_restore = None + _clm_loss_restore = None if self.enable_mtp_training: _unwrapped = model while hasattr(_unwrapped, 'module'): _unwrapped = _unwrapped.module - _saved_mtp = getattr(_unwrapped, 'mtp', None) - _saved_mtp_process = getattr( - _unwrapped, 'mtp_process', None - ) - _saved_mtp_layers = getattr( - _unwrapped.config, 'mtp_num_layers', None - ) - if ( - _saved_mtp is not None - or _saved_mtp_process - or _saved_mtp_layers is not None - ): - _unwrapped.mtp = None - _unwrapped.mtp_process = False - _unwrapped.config.mtp_num_layers = None - _mtp_restore = ( - _unwrapped, - _saved_mtp, - _saved_mtp_process, - _saved_mtp_layers, + + if forward_only: + # -- Inference: disable MTP to avoid crash -- + _saved_mtp = getattr(_unwrapped, 'mtp', None) + _saved_mtp_process = getattr( + _unwrapped, 'mtp_process', None ) - self.logger.debug( - f"[MTPTrain] Temporarily disabled MTP in " - f"_postprocess (forward_only={forward_only}, " - f"saved_mtp_process={_saved_mtp_process}, " - f"saved_mtp_num_layers={_saved_mtp_layers})" + _saved_mtp_layers = getattr( + _unwrapped.config, 'mtp_num_layers', None + ) + if ( + _saved_mtp is not None + or _saved_mtp_process + or _saved_mtp_layers is not None + ): + _unwrapped.mtp = None + _unwrapped.mtp_process = False + _unwrapped.config.mtp_num_layers = None + _mtp_restore = ( + _unwrapped, + _saved_mtp, + _saved_mtp_process, + _saved_mtp_layers, + ) + self.logger.info( + "[MTPTrain] Disabled MTP in _postprocess for " + "inference (forward_only=True)" + ) + else: + # -- Training: enable MTP with labels & loss_mask -- + # Construct causal-LM labels from padded input_ids. + _input_ids = mb_input.padded_mb["input_ids"] + _mtp_labels = torch.roll( + _input_ids, shifts=-1, dims=-1 + ) + # loss_mask carried through pack/pad pipeline; + # fall back to None → megatron uses ones_like. + _mtp_loss_mask = mb_input.padded_mb.get( + "loss_mask", None + ) + extra_block_kwargs = {"labels": _mtp_labels} + if _mtp_loss_mask is not None: + extra_block_kwargs["loss_mask"] = _mtp_loss_mask + + # Monkey-patch: make the LAST call to + # compute_language_model_loss (the main CE loss) + # return logits so AReaL gets logits, not loss. + _remaining = [self.mtp_num_layers] + _orig_clm = _unwrapped.compute_language_model_loss + + def _mtp_loss_fn( + _labels, _logits, + _rem=_remaining, _orig=_orig_clm, + ): + if _rem[0] > 0: + _rem[0] -= 1 + return _orig(_labels, _logits) + # Return logits in [b, s, v] matching the + # ``if labels is None`` path in _postprocess. + return _logits.transpose(0, 1).contiguous() + + _unwrapped.compute_language_model_loss = _mtp_loss_fn + _clm_loss_restore = (_unwrapped, _orig_clm) + + self.logger.info( + f"[MTPTrain] MTP enabled for training " + f"(mtp_num_layers={self.mtp_num_layers}, " + f"labels_shape={_mtp_labels.shape}, " + f"loss_mask={'yes' if _mtp_loss_mask is not None else 'no'})" ) try: @@ -960,8 +1013,11 @@ def forward_step(batch_iter, model): _uw.mtp_process = _sp _uw.config.mtp_num_layers = _sl self.logger.debug( - "[MTPTrain] Restored MTP after forward pass" + "[MTPTrain] Restored MTP after inference forward" ) + if _clm_loss_restore is not None: + _uw, _orig = _clm_loss_restore + _uw.compute_language_model_loss = _orig # Release tree attention metadata after forward pass for key in tree_attn_keys: diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 2b7fbf41e0..1b528d2fbf 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -117,7 +117,7 @@ sglang: random_seed: ${seed} skip_tokenizer_init: true dtype: ${actor.dtype} - max_running_requests: null + max_running_requests: 96 context_length: 32768 mem_fraction_static: 0.4 From dd4eb5ec6ac7d424886fae2af580ea566c1283d7 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 15 Apr 2026 11:53:39 +0800 Subject: [PATCH 026/140] feat: fix ckpt --- areal/engine/megatron_engine.py | 128 +++++++++++++++++++++++++++++++- 1 file changed, 127 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ff0324e7a1..1e10d0fe0d 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -926,6 +926,7 @@ def forward_step(batch_iter, model): extra_block_kwargs = None _mtp_restore = None _clm_loss_restore = None + _mtp_ckpt_restore = [] # (layer, orig_method) pairs if self.enable_mtp_training: _unwrapped = model while hasattr(_unwrapped, 'module'): @@ -994,7 +995,130 @@ def _mtp_loss_fn( _unwrapped.compute_language_model_loss = _mtp_loss_fn _clm_loss_restore = (_unwrapped, _orig_clm) - self.logger.info( + # ----------------------------------------------------------- + # Megatron-Core 0.16.0 MTP _checkpointed_forward() does: + # tensor_parallel.checkpoint(fn, ..., *args, *kwargs.values()) + # This flattens ALL kwargs (including packed_seq_params which + # is a dataclass, not a tensor) into positional args that end + # up in CheckpointFunction.apply() → save_for_backward(), + # which only accepts tensors → TypeError. + # + # The main TransformerBlock avoids this by capturing + # packed_seq_params via closure (never passed as an arg). + # We apply the same pattern here by monkey-patching each + # MTP layer's _checkpointed_forward during training. + # ----------------------------------------------------------- + _mtp_block = getattr(_unwrapped, 'mtp', None) + if ( + _mtp_block is not None + and hasattr(_mtp_block, 'layers') + and _unwrapped.config.recompute_granularity == 'full' + ): + for _layer in _mtp_block.layers: + _orig_ckpt_fwd = _layer._checkpointed_forward + + def _patched_checkpointed_forward( + forward_func, *args, + _layer_ref=_layer, + **kwargs, + ): + """Closure-based checkpoint that keeps + non-tensor args (packed_seq_params, + inference_params) out of save_for_backward. + + Mirrors TransformerBlock._checkpointed_forward + from megatron-core 0.16.0: non-tensor kwargs + are captured in the closure of custom_forward, + only tensor values go through checkpoint(). + """ + # Separate tensor vs non-tensor kwargs. + _tensor_kw = {} + _non_tensor_kw = {} + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + _tensor_kw[k] = v + else: + _non_tensor_kw[k] = v + + # Build a wrapper that re-injects non-tensor + # kwargs via closure (never saved by + # checkpoint). + def _ckpt_wrapper(*flat_args): + # Reconstruct kwargs: first the tensor + # ones from flat_args, then non-tensor + # from closure. + _tk_keys = list(_tensor_kw.keys()) + # flat_args = original *args + tensor kw + # values in order. + n_orig = len(args) + _orig_args = flat_args[:n_orig] + _tk_vals = flat_args[n_orig:] + _rebuilt_kw = { + k: v for k, v in zip( + _tk_keys, _tk_vals + ) + } + _rebuilt_kw.update(_non_tensor_kw) + return forward_func( + *_orig_args, **_rebuilt_kw + ) + + _cfg = _layer_ref.config + if _cfg.recompute_method == 'uniform': + assert ( + _cfg.recompute_num_layers == 1 + ), ( + "recompute_num_layers must be 1 " + "for MTP recompute" + ) + if _cfg.fp8: + from megatron.core.extensions.transformer_engine import ( + te_checkpoint, + ) + return te_checkpoint( + _ckpt_wrapper, + _cfg.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + *args, + *_tensor_kw.values(), + ) + else: + return tensor_parallel.checkpoint( + _ckpt_wrapper, + _cfg.distribute_saved_activations, + *args, + *_tensor_kw.values(), + ) + elif _cfg.recompute_method == 'block': + import warnings + warnings.warn( + "recompute_method == 'block' is not " + "supported for MTP yet. " + "Skipping recompute." + ) + return forward_func(*args, **kwargs) + else: + raise ValueError( + "Invalid activation recompute method." + ) + + _layer._checkpointed_forward = ( + _patched_checkpointed_forward + ) + _mtp_ckpt_restore.append( + (_layer, _orig_ckpt_fwd) + ) + + self.logger.info( + f"[MTPTrain] Patched _checkpointed_forward on " + f"{len(_mtp_ckpt_restore)} MTP layer(s) to fix " + f"gradient_checkpointing + PackedSeqParams crash " + f"(recompute_granularity=" + f"{_unwrapped.config.recompute_granularity})" + ) + + self.logger.debug( f"[MTPTrain] MTP enabled for training " f"(mtp_num_layers={self.mtp_num_layers}, " f"labels_shape={_mtp_labels.shape}, " @@ -1018,6 +1142,8 @@ def _mtp_loss_fn( if _clm_loss_restore is not None: _uw, _orig = _clm_loss_restore _uw.compute_language_model_loss = _orig + for _layer, _orig_ckpt in _mtp_ckpt_restore: + _layer._checkpointed_forward = _orig_ckpt # Release tree attention metadata after forward pass for key in tree_attn_keys: From 7c2754420d61d85e303d0b7b96a7312cf037e896 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 15 Apr 2026 15:47:41 +0800 Subject: [PATCH 027/140] feat: fix config oom --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 1b528d2fbf..43f23cc9f6 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -19,7 +19,7 @@ scheduler: type: null rollout: - backend: "sglang:d2p1t2" + backend: "sglang:d1p1t2" experiment_name: ${experiment_name} trial_name: ${trial_name} max_concurrent_rollouts: 256 @@ -51,7 +51,7 @@ actor: mb_spec: max_tokens_per_mb: 10240 optimizer: - type: adam + type: adam_bf16 lr: 3e-6 weight_decay: 0.003 beta1: 0.9 From f719dadaecaf1c4d6dbdc441883f495bf52e78d4 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 15 Apr 2026 17:38:21 +0800 Subject: [PATCH 028/140] feat: fix no mtp --- areal/engine/megatron_engine.py | 58 ++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 1e10d0fe0d..d621a08dd3 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -354,13 +354,21 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): f"mtp_detach_heads={self.mtp_detach_heads}" ) else: - if getattr(self.tf_config, "mtp_num_layers", 0) > 0: + # When MTP training is disabled, clear mtp_num_layers to + # prevent mbridge from creating MTP layers. Without this, + # models like MiMo whose HF config contains + # num_nextn_predict_layers>0 would still create MTP layers + # through mbridge, causing _postprocess() to enter the MTP + # loss path and crash on labels.clone() when labels is None + # during inference. + _orig_mtp = getattr(self.tf_config, "mtp_num_layers", None) + if _orig_mtp is not None and _orig_mtp > 0: + self.tf_config.mtp_num_layers = None self.logger.info( - f"[MTPTrain] MTP training disabled but tf_config.mtp_num_layers=" - f"{self.tf_config.mtp_num_layers}. Resetting to 0 to prevent " - f"mbridge from creating MTP layers (avoids Invalid spec error)." + f"[MTPConfig] Cleared tf_config.mtp_num_layers " + f"(was {_orig_mtp}) because enable_mtp_training=False. " + f"MTP layers will NOT be created in GPTModel." ) - self.tf_config.mtp_num_layers = 0 with self.device: models = make_mcore_model( @@ -927,6 +935,46 @@ def forward_step(batch_iter, model): _mtp_restore = None _clm_loss_restore = None _mtp_ckpt_restore = [] # (layer, orig_method) pairs + + # Defensive guard: even when enable_mtp_training=False, the + # model may still have MTP artefacts (e.g. config.mtp_num_layers + # leaked from HF/mbridge config, or MTP layers loaded from a + # checkpoint). During inference this causes _postprocess() to + # enter the MTP loss path and crash on labels.clone() when + # labels is None. Disable MTP at runtime in this case. + if not self.enable_mtp_training and forward_only: + _unwrapped_def = model + while hasattr(_unwrapped_def, 'module'): + _unwrapped_def = _unwrapped_def.module + _def_mtp = getattr(_unwrapped_def, 'mtp', None) + _def_mtp_process = getattr( + _unwrapped_def, 'mtp_process', False + ) + _def_mtp_layers = getattr( + _unwrapped_def.config, 'mtp_num_layers', None + ) + if ( + _def_mtp is not None + or _def_mtp_process + or _def_mtp_layers is not None + ): + _unwrapped_def.mtp = None + _unwrapped_def.mtp_process = False + _unwrapped_def.config.mtp_num_layers = None + _mtp_restore = ( + _unwrapped_def, + _def_mtp, + _def_mtp_process, + _def_mtp_layers, + ) + self.logger.debug( + f"[MTPGuard] Disabled MTP for inference " + f"(enable_mtp_training=False but model had " + f"mtp={_def_mtp is not None}, " + f"mtp_process={_def_mtp_process}, " + f"mtp_num_layers={_def_mtp_layers})" + ) + if self.enable_mtp_training: _unwrapped = model while hasattr(_unwrapped, 'module'): From 1bcabea534754df8beb6f2dec2166686de900483 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 15 Apr 2026 17:41:48 +0800 Subject: [PATCH 029/140] feat: fix config --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- examples/math/gsm8k_grpo_megatron_mimo_base.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 43f23cc9f6..06eab1fbf9 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -117,7 +117,7 @@ sglang: random_seed: ${seed} skip_tokenizer_init: true dtype: ${actor.dtype} - max_running_requests: 96 + max_running_requests: null context_length: 32768 mem_fraction_static: 0.4 diff --git a/examples/math/gsm8k_grpo_megatron_mimo_base.yaml b/examples/math/gsm8k_grpo_megatron_mimo_base.yaml index 286721c815..5d115eed35 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo_base.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo_base.yaml @@ -19,7 +19,7 @@ scheduler: type: null rollout: - backend: "sglang:d2p1t2" + backend: "sglang:d1p1t2" experiment_name: ${experiment_name} trial_name: ${trial_name} max_concurrent_rollouts: 256 @@ -51,7 +51,7 @@ actor: mb_spec: max_tokens_per_mb: 10240 optimizer: - type: adam + type: adam_bf16 lr: 3e-6 weight_decay: 0.003 beta1: 0.9 From 107b8f35cd888aeb281f2b806898976db6ed1e92 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 15 Apr 2026 18:36:41 +0800 Subject: [PATCH 030/140] feat: fix --- examples/math/gsm8k_grpo_megatron_mimo_base.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo_base.yaml b/examples/math/gsm8k_grpo_megatron_mimo_base.yaml index 5d115eed35..3ad3390bbc 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo_base.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo_base.yaml @@ -49,7 +49,7 @@ actor: gradient_checkpointing: true dtype: bfloat16 mb_spec: - max_tokens_per_mb: 10240 + max_tokens_per_mb: 5120 optimizer: type: adam_bf16 lr: 3e-6 @@ -95,7 +95,7 @@ ref: disable_dropout: true dtype: ${actor.dtype} mb_spec: - max_tokens_per_mb: 10240 + max_tokens_per_mb: 5120 optimizer: null scheduling_strategy: type: colocation @@ -110,7 +110,7 @@ sglang: dtype: ${actor.dtype} max_running_requests: null context_length: 32768 - mem_fraction_static: 0.4 + mem_fraction_static: 0.7 vllm: From 7cb5e49f7f180fbf8b438da9b213d7e48ec88e0c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 15 Apr 2026 21:44:52 +0800 Subject: [PATCH 031/140] feat: fix --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 6 +++--- examples/math/gsm8k_grpo_megatron_mimo_base.yaml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 06eab1fbf9..b6cdb3a2ca 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -25,7 +25,7 @@ rollout: max_concurrent_rollouts: 256 queue_size: null consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 2 + max_head_offpolicyness: 0 enable_rollout_tracing: false scheduling_spec: ${actor.scheduling_spec} fileroot: ${cluster.fileroot} @@ -49,7 +49,7 @@ actor: gradient_checkpointing: true dtype: bfloat16 mb_spec: - max_tokens_per_mb: 10240 + max_tokens_per_mb: 5120 optimizer: type: adam_bf16 lr: 3e-6 @@ -104,7 +104,7 @@ ref: disable_dropout: true dtype: ${actor.dtype} mb_spec: - max_tokens_per_mb: 10240 + max_tokens_per_mb: 5120 optimizer: null scheduling_strategy: type: colocation diff --git a/examples/math/gsm8k_grpo_megatron_mimo_base.yaml b/examples/math/gsm8k_grpo_megatron_mimo_base.yaml index 3ad3390bbc..c5399fcf91 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo_base.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo_base.yaml @@ -25,7 +25,7 @@ rollout: max_concurrent_rollouts: 256 queue_size: null consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 2 + max_head_offpolicyness: 0 enable_rollout_tracing: false scheduling_spec: ${actor.scheduling_spec} fileroot: ${cluster.fileroot} From 7cbfe81f7cbca16f127d2324ea7782baaa1bb6d9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 16 Apr 2026 10:27:51 +0800 Subject: [PATCH 032/140] feat: add qwen --- .../math/gsm8k_grpo_megatron_qwen_base.yaml | 180 ++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 examples/math/gsm8k_grpo_megatron_qwen_base.yaml diff --git a/examples/math/gsm8k_grpo_megatron_qwen_base.yaml b/examples/math/gsm8k_grpo_megatron_qwen_base.yaml new file mode 100644 index 0000000000..b24ca1c0ea --- /dev/null +++ b/examples/math/gsm8k_grpo_megatron_qwen_base.yaml @@ -0,0 +1,180 @@ +experiment_name: gsm8k-grpo-megatron +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + + +scheduler: + type: null + +rollout: + backend: "sglang:d1p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 0 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 2048 + greedy: false + temperature: 1.0 + +actor: + backend: "megatron:d1p1t4" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: /workspace/models/Qwen3.5-4B + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 5120 + optimizer: + type: adam_bf16 + lr: 3e-6 + weight_decay: 0.003 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behave_imp_weight_cap: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 5120 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.7 + + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.9 + +# datasets +train_dataset: + batch_size: 128 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 128 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false From 0a3206611805f05cfd4aa58b86fd7ed3ad215f4d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 16 Apr 2026 14:21:12 +0800 Subject: [PATCH 033/140] feat: remove enable_draft_weights_cpu_backup --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index b6cdb3a2ca..196c7dbee8 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -128,7 +128,7 @@ sglang: speculative_eagle_topk: 1 speculative_num_draft_tokens: 4 speculative_attention_mode: null - enable_draft_weights_cpu_backup: true + enable_draft_weights_cpu_backup: false vllm: From 93baaf8c3cb90b8922085b40e6250a70d14e6f31 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 16 Apr 2026 16:23:18 +0800 Subject: [PATCH 034/140] feat: fix mtp loss --- areal/engine/megatron_engine.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index d621a08dd3..52207307c4 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -815,15 +815,14 @@ def _collect_mtp_loss(self) -> dict[str, float]: f"[MTPTrain] MTP loss is NaN/Inf! value={mtp_loss_value}. " f"Check MTP label construction and model configuration." ) - elif mtp_loss_value < 0 or mtp_loss_value > 100: - self.logger.warning( - f"[MTPTrain] MTP loss {mtp_loss_value:.6f} outside expected range [0, 100]." - ) else: + # Note: mtp_loss_value is the SUM of per-micro-batch + # average MTP losses (accumulated via += in the tracker). + # This is by design in Megatron-Core. For N micro-batches + # the value ≈ N * per_token_mtp_loss. self.logger.info( - f"[MTPTrain] MTP loss={mtp_loss_value:.6f}, " + f"[MTPTrain] MTP loss (accumulated)={mtp_loss_value:.6f}, " f"scaling_factor={self.mtp_loss_scaling_factor}, " - f"scaled_mtp_loss={mtp_loss_value * self.mtp_loss_scaling_factor:.6f}, " f"is_last_pp_stage={is_last_pp_stage}" ) @@ -1009,11 +1008,13 @@ def forward_step(batch_iter, model): ) else: # -- Training: enable MTP with labels & loss_mask -- - # Construct causal-LM labels from padded input_ids. + # Pass raw input_ids as MTP labels (NOT pre-shifted). + # Megatron-Core _postprocess() calls roll_tensor(labels, -1) + # internally for each MTP layer, so MTP layer k predicts + # token at position i+(k+1). This matches the slime + # implementation which passes batch["tokens"] directly. _input_ids = mb_input.padded_mb["input_ids"] - _mtp_labels = torch.roll( - _input_ids, shifts=-1, dims=-1 - ) + _mtp_labels = _input_ids # loss_mask carried through pack/pad pipeline; # fall back to None → megatron uses ones_like. _mtp_loss_mask = mb_input.padded_mb.get( @@ -2031,6 +2032,8 @@ def _save_model_to_hf( base_model_path: str | None = None, ) -> None: assert self.model is not None, "Model is not initialized." + gc.collect() + current_platform.empty_cache() os.makedirs(path, exist_ok=True) if self.bridge_cls == "megatron-bridge": From cdeebcbfb2a538d8d3b925b46543ba8544545c98 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 16 Apr 2026 17:11:45 +0800 Subject: [PATCH 035/140] feat: fix OOM --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 196c7dbee8..d2451ca4d4 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -35,7 +35,7 @@ rollout: gconfig: n_samples: 4 min_new_tokens: 0 - max_new_tokens: 2048 + max_new_tokens: 1024 greedy: false temperature: 1.0 @@ -49,7 +49,7 @@ actor: gradient_checkpointing: true dtype: bfloat16 mb_spec: - max_tokens_per_mb: 5120 + max_tokens_per_mb: 2048 optimizer: type: adam_bf16 lr: 3e-6 @@ -92,6 +92,7 @@ actor: env_vars: {} megatron: + distribute_saved_activations: true mtp_num_layers: 1 mtp_loss_scaling_factor: 0.2 @@ -104,7 +105,7 @@ ref: disable_dropout: true dtype: ${actor.dtype} mb_spec: - max_tokens_per_mb: 5120 + max_tokens_per_mb: 2048 optimizer: null scheduling_strategy: type: colocation From f79b5c50a5ba13e6929f84e3bd86fce28552d18b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 16 Apr 2026 17:16:36 +0800 Subject: [PATCH 036/140] feat: revert --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index d2451ca4d4..01c8f69219 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -92,7 +92,6 @@ actor: env_vars: {} megatron: - distribute_saved_activations: true mtp_num_layers: 1 mtp_loss_scaling_factor: 0.2 From f5535720265096f1c14676b0e5d628265174afe6 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 16 Apr 2026 17:52:10 +0800 Subject: [PATCH 037/140] feat: add mem log --- areal/engine/megatron_engine.py | 12 ++++++++++++ areal/trainer/rl_trainer.py | 11 +++++++++++ 2 files changed, 23 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 52207307c4..abdd968989 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1240,9 +1240,11 @@ def train_batch( ) -> dict[str, float]: self._ensure_ready() self.optimizer_zero_grad() + DeviceRuntimeInfo.get_current().log("train_batch after zero_grad") # Step 1: Prepare micro-batches mb_list = self._prepare_mb_list(input_).to(self.device) + DeviceRuntimeInfo.get_current().log("train_batch after prepare_mb") # Step 2: Compute total loss weight total_loss_weight = compute_total_loss_weight( @@ -1267,6 +1269,7 @@ def process_output( ) self.forward_backward_batch(mb_list, process_output, forward_only=False) + DeviceRuntimeInfo.get_current().log("train_batch after forward_backward") # Step 4: Collect MTP loss after forward-backward mtp_loss_stats = {} @@ -1275,6 +1278,7 @@ def process_output( # Step 5: Optimizer step train_stats = self.optimizer_step() + DeviceRuntimeInfo.get_current().log("train_batch after optimizer_step") # Merge MTP stats into train stats if mtp_loss_stats: @@ -1344,6 +1348,7 @@ def forward_batch( # Step 2: Prepare micro-batches mb_list = self._prepare_mb_list(input_).to(self.device) + DeviceRuntimeInfo.get_current().log("forward_batch after prepare_mb") # Step 3: Forward using Megatron's pipeline function, collecting results outputs: list[torch.Tensor] = [] @@ -1354,6 +1359,7 @@ def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None: return None self.forward_backward_batch(mb_list, process_output, forward_only=True) + DeviceRuntimeInfo.get_current().log("forward_batch after forward_backward") # Step 4: Aggregate, reorder, and broadcast outputs res = None @@ -1912,6 +1918,7 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta) -> None: @trace_perf("megatron_engine.update_weights_from_distributed", category="comm") def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: + DeviceRuntimeInfo.get_current().log("_update_weights_from_distributed start") # Reset weight weight meta with local info meta.nccl_master_address = self.weight_update_master_addr meta.nccl_master_port = self.weight_update_master_port @@ -2001,6 +2008,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: @trace_perf("megatron_engine.update_weights_from_disk", category="io") def _update_weights_from_disk(self, meta: WeightUpdateMeta) -> None: + DeviceRuntimeInfo.get_current().log("_update_weights_from_disk start") fut = Future() if dist.get_rank() == 0: @@ -2032,8 +2040,10 @@ def _save_model_to_hf( base_model_path: str | None = None, ) -> None: assert self.model is not None, "Model is not initialized." + DeviceRuntimeInfo.get_current().log("_save_model_to_hf before gc/empty_cache") gc.collect() current_platform.empty_cache() + DeviceRuntimeInfo.get_current().log("_save_model_to_hf after gc/empty_cache") os.makedirs(path, exist_ok=True) if self.bridge_cls == "megatron-bridge": @@ -2066,6 +2076,8 @@ def _save_model_to_hf( fp8_direct_convert=self.fp8_direct_convert, ) + DeviceRuntimeInfo.get_current().log("_save_model_to_hf after save_weights") + if dist.get_rank() == 0: if tokenizer is not None: tokenizer.save_pretrained(path) diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index d204998ba4..d8e93cbc5e 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -33,6 +33,7 @@ ValidDatasetConfig, vLLMConfig, ) +from areal.api.io_struct import DeviceRuntimeInfo from areal.engine import RemoteSGLangEngine, RemotevLLMEngine from areal.infra import ( LocalScheduler, @@ -357,6 +358,7 @@ def train( group_size=config.gconfig.n_samples, dynamic_bs=self.config.dynamic_bs, ) + DeviceRuntimeInfo.get_current().log(f"step {global_step} after rollout") if self.critic is not None: with ( @@ -385,6 +387,7 @@ def train( for traj, logp in zip(rollout_batch, prox_logps): traj["prox_logp"] = logp self.actor.get_device_stats().log("recompute logp") + DeviceRuntimeInfo.get_current().log(f"step {global_step} after recompute_logp") if self.ref is not None: with ( @@ -428,6 +431,7 @@ def train( ): adv_batch = self.actor.compute_advantages(rollout_batch) self.actor.get_device_stats().log("compute advantages") + DeviceRuntimeInfo.get_current().log(f"step {global_step} after compute_advantages") # Wait for async checkpoint staging to complete before modifying parameters self.saver.maybe_wait_for_staging() @@ -443,6 +447,7 @@ def train( self.actor.ppo_update(adv_batch) self.actor.step_lr_scheduler() self.actor.get_device_stats().log("ppo update") + DeviceRuntimeInfo.get_current().log(f"step {global_step} after ppo_update") if self.critic is not None: with ( @@ -480,6 +485,8 @@ def train( if self.eval_rollout is not None: self.eval_rollout.set_version(new_version) + DeviceRuntimeInfo.get_current().log(f"step {global_step} after update_weights") + with ( stats_tracker.record_timing("save"), perf_tracer.trace_scope( @@ -488,6 +495,7 @@ def train( args={"global_step": global_step}, ), ): + DeviceRuntimeInfo.get_current().log(f"step {global_step} before _save_hf") self._save_hf(epoch=epoch, epoch_step=step, global_step=global_step) with ( @@ -498,6 +506,7 @@ def train( args={"global_step": global_step}, ), ): + DeviceRuntimeInfo.get_current().log(f"step {global_step} before _save_recover_checkpoint") self._save_recover_checkpoint( epoch=epoch, epoch_step=step, global_step=global_step ) @@ -530,6 +539,8 @@ def train( # calling `clear_batches` once should be sufficient. self.actor.clear_batches(rollout_batch, adv_batch) + DeviceRuntimeInfo.get_current().log(f"step {global_step} after clear_batches") + with perf_tracer.trace_scope( "train.log_stats", category=Category.INSTR, From 53a3b2d5427370fdd54684c683d74c5493e943b6 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 16 Apr 2026 18:24:36 +0800 Subject: [PATCH 038/140] feat: rm log --- areal/engine/megatron_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index abdd968989..3d69c5e719 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1002,7 +1002,7 @@ def forward_step(batch_iter, model): _saved_mtp_process, _saved_mtp_layers, ) - self.logger.info( + self.logger.debug( "[MTPTrain] Disabled MTP in _postprocess for " "inference (forward_only=True)" ) @@ -1159,7 +1159,7 @@ def _ckpt_wrapper(*flat_args): (_layer, _orig_ckpt_fwd) ) - self.logger.info( + self.logger.debug( f"[MTPTrain] Patched _checkpointed_forward on " f"{len(_mtp_ckpt_restore)} MTP layer(s) to fix " f"gradient_checkpointing + PackedSeqParams crash " From b4bbb1f70936071b463d2c888aabb69100412bdf Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 16 Apr 2026 18:41:30 +0800 Subject: [PATCH 039/140] feat: sample log --- areal/workflow/rlvr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index d00be231fc..dccb3a4428 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -1,3 +1,4 @@ +import random import uuid from collections.abc import Callable from typing import Any @@ -152,7 +153,7 @@ async def _collect_samples( f"2) Checking MTP layer training status; " f"3) Reducing speculative_num_steps." ) - else: + elif random.random() < 0.01: logger.info( f"[SpecDec] Accept rate: {accept_rate:.4f} " f"(accept={resp.spec_accept_token_num}, draft={resp.spec_draft_token_num})" From df7d9184229b2ef57af85b1cadde38c097230bb6 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 17 Apr 2026 01:45:12 +0800 Subject: [PATCH 040/140] feat: fix --- examples/math/gsm8k_grpo_megatron_mimo_base.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo_base.yaml b/examples/math/gsm8k_grpo_megatron_mimo_base.yaml index c5399fcf91..d79ca143cb 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo_base.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo_base.yaml @@ -35,7 +35,7 @@ rollout: gconfig: n_samples: 4 min_new_tokens: 0 - max_new_tokens: 2048 + max_new_tokens: 1024 greedy: false temperature: 1.0 @@ -49,7 +49,7 @@ actor: gradient_checkpointing: true dtype: bfloat16 mb_spec: - max_tokens_per_mb: 5120 + max_tokens_per_mb: 2048 optimizer: type: adam_bf16 lr: 3e-6 @@ -95,7 +95,7 @@ ref: disable_dropout: true dtype: ${actor.dtype} mb_spec: - max_tokens_per_mb: 5120 + max_tokens_per_mb: 2048 optimizer: null scheduling_strategy: type: colocation @@ -110,7 +110,7 @@ sglang: dtype: ${actor.dtype} max_running_requests: null context_length: 32768 - mem_fraction_static: 0.7 + mem_fraction_static: 0.6 vllm: From f50e6045605c9ef53eb39d87f53ccda4f4e55e5d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 17 Apr 2026 12:52:05 +0800 Subject: [PATCH 041/140] feat: fix mtp gradient --- areal/engine/megatron_engine.py | 428 +++++++++++++++++++++++++++++++- 1 file changed, 426 insertions(+), 2 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 3d69c5e719..ae4986b575 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -197,8 +197,9 @@ def __init__(self, config: TrainEngineConfig): self.logger.info( "[MTPTrain] Verified megatron-core MTP module available. " - "Gradient isolation (embedding detach + functional_call lm_head) " - "is handled internally by megatron-core MultiTokenPrediction module." + "Gradient isolation is handled by AReaL monkey-patches: " + "_MTPGradIsolator (backbone), functional_call (lm_head), " + "decoder_input.detach (embedding) when mtp_detach_heads=True." ) except ImportError: self.logger.error( @@ -826,6 +827,45 @@ def _collect_mtp_loss(self) -> dict[str, float]: f"is_last_pp_stage={is_last_pp_stage}" ) + # Log gradient norms for MTP vs non-MTP parameters + # to verify gradient isolation is working correctly. + if is_last_pp_stage and self.mtp_detach_heads: + try: + mtp_grad_sq = 0.0 + non_mtp_grad_sq = 0.0 + mtp_cnt = 0 + non_mtp_cnt = 0 + emb_grad_sq = 0.0 + lmhead_grad_sq = 0.0 + for module in self.model: + for name, param in module.named_parameters(): + if param.grad is not None: + g = param.grad.data.float().norm() ** 2 + if ".mtp." in name: + mtp_grad_sq += g.item() + mtp_cnt += 1 + else: + non_mtp_grad_sq += g.item() + non_mtp_cnt += 1 + if "embedding" in name and ".mtp." not in name: + emb_grad_sq += g.item() + if "output_layer" in name and ".mtp." not in name: + lmhead_grad_sq += g.item() + self.logger.info( + f"[MTPDetach] Gradient norms after backward: " + f"mtp_params={mtp_grad_sq**0.5:.6f} ({mtp_cnt}), " + f"non_mtp_params={non_mtp_grad_sq**0.5:.6f} ({non_mtp_cnt}), " + f"embedding={emb_grad_sq**0.5:.6f}, " + f"lm_head={lmhead_grad_sq**0.5:.6f}. " + f"All non-MTP norms should be ~GRPO scale only." + ) + mtp_stats["mtp_grad_norm"] = mtp_grad_sq ** 0.5 + mtp_stats["non_mtp_grad_norm"] = non_mtp_grad_sq ** 0.5 + except Exception as e: + self.logger.warning( + f"[MTPDetach] Failed to compute gradient norms: {e}" + ) + MTPLossLoggingHelper.clean_loss_in_tracker() else: if self.enable_mtp_training: @@ -933,6 +973,8 @@ def forward_step(batch_iter, model): extra_block_kwargs = None _mtp_restore = None _clm_loss_restore = None + _postprocess_restore = None # for _postprocess gradient isolation patch + _mtp_get_emb_restore = [] # for _get_embeddings gradient isolation patch _mtp_ckpt_restore = [] # (layer, orig_method) pairs # Defensive guard: even when enable_mtp_training=False, the @@ -1024,6 +1066,383 @@ def forward_step(batch_iter, model): if _mtp_loss_mask is not None: extra_block_kwargs["loss_mask"] = _mtp_loss_mask + # In Megatron-Core 0.16.0, MTP CE loss gradient leaks to + # backbone through 3 paths: + # + # Path 1: MTP loss → MTPLossAutoScaler → hidden_states → backbone + # MTPLossAutoScaler.apply(hidden_states, mtp_loss) attaches + # mtp_loss gradient to main model's hidden_states. + # Fix: Monkey-patch _postprocess with _MTPGradIsolator. + # + # Path 2: MTP loss → output_layer (lm_head) weights + # MTP logits use the SHARED output_layer and output_weight. + # MTP CE loss backpropagates through lm_head weights. + # Fix: Detach output_weight in _postprocess MTP loop, and + # use functional_call with detached params for output_layer. + # + # Path 3: MTP loss → embedding weights + # MTP layers call embedding(input_ids, position_ids) using + # the SHARED embedding layer. Gradient flows through + # decoder_input back to embedding weights. + # Fix: Patch _get_embeddings to detach decoder_input. + # ----------------------------------------------------------- + if self.mtp_detach_heads: + _orig_postprocess = _unwrapped._postprocess.__func__ + + class _MTPGradIsolator(torch.autograd.Function): + """Gradient isolator for MTP loss (Path 1). + + Bridges original hidden_states with MTP-wrapped + hidden_states to prevent MTP CE gradient from + flowing through MTPLossAutoScaler → backbone. + + MTP params still get gradients because + MTPLossAutoScaler.backward() sends + ones_like(mtp_loss) * scale to mtp_loss regardless + of grad_output. + """ + + @staticmethod + def forward(ctx, original_hs, mtp_wrapped_hs): + return original_hs.clone() + + @staticmethod + def backward(ctx, grad_output): + return grad_output, torch.zeros_like(grad_output) + + def _patched_postprocess( + self_model, + hidden_states, + input_ids, + position_ids, + labels, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + mtp_in_postprocess=None, + loss_mask=None, + decoder_input=None, + attention_mask=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, + _orig_fn=_orig_postprocess, + _isolator=_MTPGradIsolator, + _logger=self.logger, + ): + """Patched _postprocess with comprehensive MTP gradient isolation. + + Identical to original except: + 1. _MTPGradIsolator after MTP loss loop (Path 1) + 2. output_weight.detach() for MTP loss computation (Path 2) + 3. functional_call with detached params for output_layer in MTP (Path 2) + Path 3 (embedding) is handled separately by _get_embeddings patch. + """ + from megatron.core.transformer.multi_token_prediction import ( + MTPLossAutoScaler, + MTPLossLoggingHelper, + ) + + in_inference_mode = ( + inference_context is not None + and not self_model.training + ) + if in_inference_mode: + assert runtime_gather_output, ( + "Inference must always gather TP logits" + ) + + output_weight = None + if self_model.share_embeddings_and_output_weights: + output_weight = ( + self_model.shared_embedding_or_output_weight() + ) + + if mtp_in_postprocess: + hidden_states = self_model.mtp( + input_ids=input_ids, + position_ids=position_ids, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + embedding=self_model.embedding, + **(extra_block_kwargs or {}), + ) + + if not self_model.post_process: + return hidden_states + + if self_model.config.mtp_num_layers is not None: + from megatron.core.tensor_parallel.utils import ( + roll_tensor, + ) + + mtp_labels = labels.clone() + hidden_states_list = torch.chunk( + hidden_states, + 1 + self_model.config.mtp_num_layers, + dim=0, + ) + # === GRADIENT ISOLATION Path 1: save original === + _original_hs = hidden_states_list[0] + hidden_states = hidden_states_list[0] + if loss_mask is None: + loss_mask = torch.ones_like(mtp_labels) + + # === GRADIENT ISOLATION Path 2: detach output weight === + # Use detached output_weight for MTP loss computation + # so MTP CE loss does NOT update lm_head weights. + # The main model's logit computation (below) uses the + # ORIGINAL (non-detached) output_weight. + _mtp_output_weight = ( + output_weight.detach() + if output_weight is not None + else None + ) + + for mtp_layer_number in range( + self_model.config.mtp_num_layers + ): + # Use detached output_weight for MTP logits + # and functional_call with detached output_layer + # params to prevent gradient flow to lm_head. + _mtp_hs = hidden_states_list[ + mtp_layer_number + 1 + ] + # Path 2 fix: use functional_call with detached + # output_layer params for MTP logit computation. + _ol = self_model.output_layer + _ol_params = { + k: v.detach() + for k, v in _ol.named_parameters() + } + _ol_buffers = dict(_ol.named_buffers()) + _ol_kwargs = { + 'weight': _mtp_output_weight, + 'runtime_gather_output': runtime_gather_output, + } + mtp_logits, _ = torch.func.functional_call( + _ol, + {**_ol_params, **_ol_buffers}, + (_mtp_hs,), + _ol_kwargs, + ) + + mtp_labels, _ = roll_tensor( + mtp_labels, + shifts=-1, + dims=-1, + cp_group=self_model.cp_group, + packed_seq_params=packed_seq_params, + ) + loss_mask, num_tokens = roll_tensor( + loss_mask, + shifts=-1, + dims=-1, + cp_group=self_model.cp_group, + packed_seq_params=packed_seq_params, + ) + mtp_loss = ( + self_model.compute_language_model_loss( + mtp_labels, mtp_logits + ) + ) + mtp_loss = loss_mask * mtp_loss + if self_model.training: + from megatron.core import ( + parallel_state, + ) + + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + mtp_layer_number, + self_model.config.mtp_num_layers, + avg_group=parallel_state.get_data_parallel_group( + with_context_parallel=True + ), + ) + mtp_loss_scale = ( + self_model.config.mtp_loss_scaling_factor + / self_model.config.mtp_num_layers + ) + if self_model.config.calculate_per_token_loss: + hidden_states = MTPLossAutoScaler.apply( + hidden_states, + mtp_loss_scale * mtp_loss, + ) + else: + hidden_states = MTPLossAutoScaler.apply( + hidden_states, + mtp_loss_scale + * mtp_loss + / num_tokens, + ) + + # === GRADIENT ISOLATION Path 1: apply isolator === + hidden_states = _isolator.apply( + _original_hs, hidden_states + ) + _logger.debug( + "[MTPDetach] Applied gradient isolation in " + "_postprocess (Path 1: _MTPGradIsolator, " + "Path 2: detached output_weight + functional_call)" + ) + + # Inference last-token optimization + sequence_parallel_override = False + if ( + in_inference_mode + and inference_context.materialize_only_last_token_logits + ): + if inference_context.is_static_batching(): + hidden_states = hidden_states[-1:, :, :] + else: + if ( + self_model.output_layer.sequence_parallel + ): + from megatron.core.tensor_parallel import ( + gather_from_sequence_parallel_region, + ) + + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + group=self_model.pg_collection.tp, + ) + self_model.output_layer.sequence_parallel = ( + False + ) + sequence_parallel_override = True + hidden_states = ( + inference_context.last_token_logits( + hidden_states.squeeze(1).unsqueeze( + 0 + ) + ).unsqueeze(1) + ) + + # Main model logits: use ORIGINAL output_weight + # (non-detached) so GRPO gradient flows to lm_head. + logits, _ = self_model.output_layer( + hidden_states, + weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) + + if sequence_parallel_override: + assert ( + in_inference_mode + and inference_context.is_dynamic_batching() + and inference_context.materialize_only_last_token_logits + ) + self_model.output_layer.sequence_parallel = ( + True + ) + + if labels is None: + return logits.transpose(0, 1).contiguous() + + loss = self_model.compute_language_model_loss( + labels, logits + ) + return loss + + import types + + _unwrapped._postprocess = types.MethodType( + _patched_postprocess, _unwrapped + ) + _postprocess_restore = ( + _unwrapped, + _orig_postprocess, + ) + + # === GRADIENT ISOLATION Path 3: embedding detach === + # Patch _get_embeddings on each MTP layer to detach + # decoder_input after embedding computation. + # This prevents MTP CE loss gradient from flowing to + # shared embedding weights. + # Reference: slime adds decoder_input = decoder_input.detach() + # in multi_token_prediction.py _get_embeddings(). + _mtp_block = getattr(_unwrapped, 'mtp', None) + if ( + _mtp_block is not None + and hasattr(_mtp_block, 'layers') + ): + for _layer in _mtp_block.layers: + _orig_get_emb = _layer._get_embeddings + + def _patched_get_embeddings( + input_ids, + position_ids, + embedding, + hidden_states, + packed_seq_params=None, + _orig=_orig_get_emb, + ): + """Patched _get_embeddings that detaches decoder_input. + + Prevents MTP CE loss gradient from flowing to + shared embedding weights. Also uses keep_graph=False + for hidden_states viewless tensor to break gradient + connection to backbone. + """ + result = _orig( + input_ids=input_ids, + position_ids=position_ids, + embedding=embedding, + hidden_states=hidden_states, + packed_seq_params=packed_seq_params, + ) + # result = (input_ids, position_ids, decoder_input, hidden_states) + _ids, _pos, _dec_input, _hs = result + # Detach decoder_input to prevent gradient flow to embedding + _dec_input = _dec_input.detach() + # Detach hidden_states to prevent gradient flow to backbone + # (equivalent to slime's keep_graph=False change) + from megatron.core.utils import make_viewless_tensor + _hs = make_viewless_tensor( + inp=_hs.detach(), + requires_grad=True, + keep_graph=False, + ) + return _ids, _pos, _dec_input, _hs + + _layer._get_embeddings = _patched_get_embeddings + _mtp_get_emb_restore.append( + (_layer, _orig_get_emb) + ) + + self.logger.debug( + f"[MTPDetach] Patched _get_embeddings on " + f"{len(_mtp_get_emb_restore)} MTP layer(s) " + f"for embedding gradient isolation (Path 3)" + ) + + self.logger.info( + "[MTPDetach] Comprehensive MTP gradient isolation " + f"enabled (mtp_detach_heads={self.mtp_detach_heads}): " + "Path 1 (_MTPGradIsolator for backbone hidden_states), " + "Path 2 (detached output_weight + functional_call for lm_head), " + "Path 3 (detached decoder_input + hidden_states for embedding). " + "MTP CE loss gradients will NOT flow through backbone, " + "lm_head, or embedding parameters." + ) + else: + self.logger.info( + "[MTPDetach] Gradient isolation DISABLED " + "(mtp_detach_heads=False). MTP CE loss gradient " + "will flow through all model parameters. This is " + "intended for pre-training, NOT for RL training." + ) + # Monkey-patch: make the LAST call to # compute_language_model_loss (the main CE loss) # return logits so AReaL gets logits, not loss. @@ -1180,6 +1599,11 @@ def _ckpt_wrapper(*flat_args): extra_block_kwargs=extra_block_kwargs, ) finally: + if _postprocess_restore is not None: + _uw, _orig_pp = _postprocess_restore + _uw._postprocess = types.MethodType(_orig_pp, _uw) + for _layer, _orig_get_emb in _mtp_get_emb_restore: + _layer._get_embeddings = _orig_get_emb if _mtp_restore is not None: _uw, _sm, _sp, _sl = _mtp_restore _uw.mtp = _sm From b40a55bf95b61d25b05278d553985fb7968fd716 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 17 Apr 2026 13:35:45 +0800 Subject: [PATCH 042/140] feat: fix again --- areal/engine/megatron_engine.py | 115 +++++++++++++------------------- 1 file changed, 45 insertions(+), 70 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ae4986b575..415f2a93d0 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -831,39 +831,37 @@ def _collect_mtp_loss(self) -> dict[str, float]: # to verify gradient isolation is working correctly. if is_last_pp_stage and self.mtp_detach_heads: try: - mtp_grad_sq = 0.0 - non_mtp_grad_sq = 0.0 - mtp_cnt = 0 - non_mtp_cnt = 0 - emb_grad_sq = 0.0 - lmhead_grad_sq = 0.0 + mtp_g = 0.0 + non_mtp_g = 0.0 + mtp_n = 0 + non_mtp_n = 0 + emb_g = 0.0 + lmh_g = 0.0 for module in self.model: for name, param in module.named_parameters(): if param.grad is not None: g = param.grad.data.float().norm() ** 2 if ".mtp." in name: - mtp_grad_sq += g.item() - mtp_cnt += 1 + mtp_g += g.item() + mtp_n += 1 else: - non_mtp_grad_sq += g.item() - non_mtp_cnt += 1 + non_mtp_g += g.item() + non_mtp_n += 1 if "embedding" in name and ".mtp." not in name: - emb_grad_sq += g.item() + emb_g += g.item() if "output_layer" in name and ".mtp." not in name: - lmhead_grad_sq += g.item() + lmh_g += g.item() self.logger.info( - f"[MTPDetach] Gradient norms after backward: " - f"mtp_params={mtp_grad_sq**0.5:.6f} ({mtp_cnt}), " - f"non_mtp_params={non_mtp_grad_sq**0.5:.6f} ({non_mtp_cnt}), " - f"embedding={emb_grad_sq**0.5:.6f}, " - f"lm_head={lmhead_grad_sq**0.5:.6f}. " - f"All non-MTP norms should be ~GRPO scale only." + f"[MTPDetach] Gradient norms: " + f"mtp={mtp_g**0.5:.6f}({mtp_n}), " + f"non_mtp={non_mtp_g**0.5:.6f}({non_mtp_n}), " + f"emb={emb_g**0.5:.6f}, lmh={lmh_g**0.5:.6f}" ) - mtp_stats["mtp_grad_norm"] = mtp_grad_sq ** 0.5 - mtp_stats["non_mtp_grad_norm"] = non_mtp_grad_sq ** 0.5 + mtp_stats["mtp_grad_norm"] = mtp_g ** 0.5 + mtp_stats["non_mtp_grad_norm"] = non_mtp_g ** 0.5 except Exception as e: self.logger.warning( - f"[MTPDetach] Failed to compute gradient norms: {e}" + f"[MTPDetach] Grad norm logging failed: {e}" ) MTPLossLoggingHelper.clean_loss_in_tracker() @@ -1133,17 +1131,13 @@ def _patched_postprocess( _isolator=_MTPGradIsolator, _logger=self.logger, ): - """Patched _postprocess with comprehensive MTP gradient isolation. - - Identical to original except: - 1. _MTPGradIsolator after MTP loss loop (Path 1) - 2. output_weight.detach() for MTP loss computation (Path 2) - 3. functional_call with detached params for output_layer in MTP (Path 2) - Path 3 (embedding) is handled separately by _get_embeddings patch. + """Patched _postprocess with comprehensive MTP + gradient isolation (Paths 1, 2, 3). """ from megatron.core.transformer.multi_token_prediction import ( MTPLossAutoScaler, MTPLossLoggingHelper, + roll_tensor, ) in_inference_mode = ( @@ -1181,27 +1175,19 @@ def _patched_postprocess( return hidden_states if self_model.config.mtp_num_layers is not None: - from megatron.core.tensor_parallel.utils import ( - roll_tensor, - ) - mtp_labels = labels.clone() hidden_states_list = torch.chunk( hidden_states, 1 + self_model.config.mtp_num_layers, dim=0, ) - # === GRADIENT ISOLATION Path 1: save original === + # Path 1: save original hidden_states _original_hs = hidden_states_list[0] hidden_states = hidden_states_list[0] if loss_mask is None: loss_mask = torch.ones_like(mtp_labels) - # === GRADIENT ISOLATION Path 2: detach output weight === - # Use detached output_weight for MTP loss computation - # so MTP CE loss does NOT update lm_head weights. - # The main model's logit computation (below) uses the - # ORIGINAL (non-detached) output_weight. + # Path 2: detach output weight for MTP _mtp_output_weight = ( output_weight.detach() if output_weight is not None @@ -1211,14 +1197,11 @@ def _patched_postprocess( for mtp_layer_number in range( self_model.config.mtp_num_layers ): - # Use detached output_weight for MTP logits - # and functional_call with detached output_layer - # params to prevent gradient flow to lm_head. + # Path 2: functional_call with detached + # output_layer params for MTP logits _mtp_hs = hidden_states_list[ mtp_layer_number + 1 ] - # Path 2 fix: use functional_call with detached - # output_layer params for MTP logit computation. _ol = self_model.output_layer _ol_params = { k: v.detach() @@ -1227,7 +1210,9 @@ def _patched_postprocess( _ol_buffers = dict(_ol.named_buffers()) _ol_kwargs = { 'weight': _mtp_output_weight, - 'runtime_gather_output': runtime_gather_output, + 'runtime_gather_output': ( + runtime_gather_output + ), } mtp_logits, _ = torch.func.functional_call( _ol, @@ -1286,14 +1271,13 @@ def _patched_postprocess( / num_tokens, ) - # === GRADIENT ISOLATION Path 1: apply isolator === + # Path 1: apply gradient isolator hidden_states = _isolator.apply( _original_hs, hidden_states ) _logger.debug( - "[MTPDetach] Applied gradient isolation in " - "_postprocess (Path 1: _MTPGradIsolator, " - "Path 2: detached output_weight + functional_call)" + "[MTPDetach] Applied gradient isolation " + "in _postprocess (Paths 1+2)" ) # Inference last-token optimization @@ -1328,8 +1312,7 @@ def _patched_postprocess( ).unsqueeze(1) ) - # Main model logits: use ORIGINAL output_weight - # (non-detached) so GRPO gradient flows to lm_head. + # Main logits: ORIGINAL output_weight (GRPO grad flows) logits, _ = self_model.output_layer( hidden_states, weight=output_weight, @@ -1364,13 +1347,7 @@ def _patched_postprocess( _orig_postprocess, ) - # === GRADIENT ISOLATION Path 3: embedding detach === - # Patch _get_embeddings on each MTP layer to detach - # decoder_input after embedding computation. - # This prevents MTP CE loss gradient from flowing to - # shared embedding weights. - # Reference: slime adds decoder_input = decoder_input.detach() - # in multi_token_prediction.py _get_embeddings(). + # Path 3: patch _get_embeddings for embedding detach _mtp_block = getattr(_unwrapped, 'mtp', None) if ( _mtp_block is not None @@ -1387,12 +1364,9 @@ def _patched_get_embeddings( packed_seq_params=None, _orig=_orig_get_emb, ): - """Patched _get_embeddings that detaches decoder_input. - - Prevents MTP CE loss gradient from flowing to - shared embedding weights. Also uses keep_graph=False - for hidden_states viewless tensor to break gradient - connection to backbone. + """Detach decoder_input and hidden_states + to prevent MTP gradient from flowing to + shared embedding and backbone parameters. """ result = _orig( input_ids=input_ids, @@ -1401,13 +1375,11 @@ def _patched_get_embeddings( hidden_states=hidden_states, packed_seq_params=packed_seq_params, ) - # result = (input_ids, position_ids, decoder_input, hidden_states) _ids, _pos, _dec_input, _hs = result - # Detach decoder_input to prevent gradient flow to embedding _dec_input = _dec_input.detach() - # Detach hidden_states to prevent gradient flow to backbone - # (equivalent to slime's keep_graph=False change) - from megatron.core.utils import make_viewless_tensor + from megatron.core.utils import ( + make_viewless_tensor, + ) _hs = make_viewless_tensor( inp=_hs.detach(), requires_grad=True, @@ -1415,7 +1387,9 @@ def _patched_get_embeddings( ) return _ids, _pos, _dec_input, _hs - _layer._get_embeddings = _patched_get_embeddings + _layer._get_embeddings = ( + _patched_get_embeddings + ) _mtp_get_emb_restore.append( (_layer, _orig_get_emb) ) @@ -1600,8 +1574,9 @@ def _ckpt_wrapper(*flat_args): ) finally: if _postprocess_restore is not None: + import types as _types_mod _uw, _orig_pp = _postprocess_restore - _uw._postprocess = types.MethodType(_orig_pp, _uw) + _uw._postprocess = _types_mod.MethodType(_orig_pp, _uw) for _layer, _orig_get_emb in _mtp_get_emb_restore: _layer._get_embeddings = _orig_get_emb if _mtp_restore is not None: From ca39b2ed7ddcb3c8487eb3d5f75377aa8e918b8f Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 17 Apr 2026 15:14:22 +0800 Subject: [PATCH 043/140] =?UTF-8?q?feat=EF=BC=9A=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- areal/engine/megatron_engine.py | 62 ++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 415f2a93d0..c7e12dc27c 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -5,6 +5,7 @@ import gc import math import os +import random from collections.abc import Callable, Iterator from concurrent.futures import Future from contextlib import nullcontext @@ -837,25 +838,35 @@ def _collect_mtp_loss(self) -> dict[str, float]: non_mtp_n = 0 emb_g = 0.0 lmh_g = 0.0 + total_params = 0 + no_grad_params = 0 for module in self.model: for name, param in module.named_parameters(): - if param.grad is not None: - g = param.grad.data.float().norm() ** 2 - if ".mtp." in name: - mtp_g += g.item() - mtp_n += 1 - else: - non_mtp_g += g.item() - non_mtp_n += 1 - if "embedding" in name and ".mtp." not in name: - emb_g += g.item() - if "output_layer" in name and ".mtp." not in name: - lmh_g += g.item() + total_params += 1 + # Megatron DDP: main_grad > grad + grad = getattr(param, "main_grad", None) + if grad is None: + grad = param.grad + if grad is None: + no_grad_params += 1 + continue + g = grad.data.float().norm() ** 2 + if ".mtp." in name: + mtp_g += g.item() + mtp_n += 1 + else: + non_mtp_g += g.item() + non_mtp_n += 1 + if "embedding" in name and ".mtp." not in name: + emb_g += g.item() + if "output_layer" in name and ".mtp." not in name: + lmh_g += g.item() self.logger.info( - f"[MTPDetach] Gradient norms: " - f"mtp={mtp_g**0.5:.6f}({mtp_n}), " - f"non_mtp={non_mtp_g**0.5:.6f}({non_mtp_n}), " - f"emb={emb_g**0.5:.6f}, lmh={lmh_g**0.5:.6f}" + f"[MTPDetach] Gradient norms (main_grad): " + f"mtp={mtp_g**0.5:.6f}({mtp_n} params), " + f"non_mtp={non_mtp_g**0.5:.6f}({non_mtp_n} params), " + f"emb={emb_g**0.5:.6f}, lmh={lmh_g**0.5:.6f}, " + f"total={total_params}, no_grad={no_grad_params}" ) mtp_stats["mtp_grad_norm"] = mtp_g ** 0.5 mtp_stats["non_mtp_grad_norm"] = non_mtp_g ** 0.5 @@ -1400,15 +1411,16 @@ def _patched_get_embeddings( f"for embedding gradient isolation (Path 3)" ) - self.logger.info( - "[MTPDetach] Comprehensive MTP gradient isolation " - f"enabled (mtp_detach_heads={self.mtp_detach_heads}): " - "Path 1 (_MTPGradIsolator for backbone hidden_states), " - "Path 2 (detached output_weight + functional_call for lm_head), " - "Path 3 (detached decoder_input + hidden_states for embedding). " - "MTP CE loss gradients will NOT flow through backbone, " - "lm_head, or embedding parameters." - ) + if random.random() < 0.001: + self.logger.info( + "[MTPDetach] Comprehensive MTP gradient isolation " + f"enabled (mtp_detach_heads={self.mtp_detach_heads}): " + "Path 1 (_MTPGradIsolator for backbone hidden_states), " + "Path 2 (detached output_weight + functional_call for lm_head), " + "Path 3 (detached decoder_input + hidden_states for embedding). " + "MTP CE loss gradients will NOT flow through backbone, " + "lm_head, or embedding parameters." + ) else: self.logger.info( "[MTPDetach] Gradient isolation DISABLED " From 1889b1af7ed61bb1eb590bfad3a4c0bcec3801e4 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 18 Apr 2026 21:35:13 +0800 Subject: [PATCH 044/140] feat(engine): add mtp weight update --- areal/engine/megatron_engine.py | 90 +++++++++++++++++ areal/engine/sglang_remote.py | 103 ++++++++++++++++++++ areal/infra/remote_inf_engine.py | 55 +++++++++++ examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 4 files changed, 249 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index c7e12dc27c..28f26848c4 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -186,6 +186,7 @@ def __init__(self, config: TrainEngineConfig): self.mtp_detach_heads: bool = getattr(self.config, "mtp_detach_heads", True) self._mtp_loss_value: float = 0.0 self._mtp_layers_verified: bool = False + self._mtp_tensor_update_warned: bool = False if self.enable_mtp_training: self.logger.info( f"[MTPTrain] MTP online training ENABLED: " @@ -631,6 +632,15 @@ def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta): f"Connected rollout engine changed from {self.rollout_engine} to {engine}." ) self.rollout_engine = engine + # Check if engine supports tensor weight updates (MTP draft sync) + self._engine_supports_tensor_update = hasattr( + engine, "update_weights_from_tensor" + ) + if self.enable_mtp_training and self._engine_supports_tensor_update: + self.logger.info( + "[MTPTrain] Inference engine supports update_weights_from_tensor. " + "MTP draft model weights will be synced via tensor update path." + ) self.rollout_coordinator = DistRolloutCoordinator( rollout_engine=engine, train_engine=self ) @@ -2348,12 +2358,35 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: mtp_param_count = 0 mtp_param_bytes = 0 + # Collect MTP weights in HF format for draft model tensor update + mtp_hf_tensors = [] + _collect_mtp_for_draft = ( + self.enable_mtp_training + and getattr(self, "_engine_supports_tensor_update", False) + and self.is_pipeline_parallel_head() + ) for name, param in get_named_parameters(self.model, num_moe_experts): if ".experts." in name: continue if ".mtp." in name: mtp_param_count += 1 mtp_param_bytes += param.numel() * param.element_size() + if _collect_mtp_for_draft: + # Collect and all-gather MTP param, then convert to HF + # format. _collect_param handles TP all-gather, padding + # removal, and FP8 dequant (same as NCCL path). + _mtp_param, _ = self._collect_param(name, param) + _mtp_model_name = self.hf_config.model_type + mtp_hf_tensors.extend( + convert_to_hf( + self.tf_config, + _mtp_model_name, + name, + _mtp_param, + quantization_config=self.quantization_config, + fp8_direct_convert=self.fp8_direct_convert, + ) + ) if self.config.use_lora and ( ".adapter." not in name or not getattr(param, "requires_grad", False) ): @@ -2388,6 +2421,63 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"during weight sync at version={meta.version}. " f"MTP draft model weights will NOT be updated!" ) + + # --- MTP Draft Model Weight Update via /update_weights_from_tensor --- + # SGLang v0.5.9 /update_weights_from_distributed routes ONLY to + # tp_worker, which does NOT update the draft_worker (EAGLEWorker). + # MTP weights sent via NCCL are silently dropped by + # MiMoForCausalLM.load_weights() ("if mtp_layers in name: continue"). + # + # use /update_weights_from_tensor + # which routes to draft_worker. EAGLEWorker.update_weights_from_tensor() + # updates BOTH self.model_runner (draft/MiMoMTP) and + # self.target_worker.model_runner (target/MiMoForCausalLM). + # Each model's load_weights() silently skips non-matching names. + if _collect_mtp_for_draft and mtp_hf_tensors and dist.get_rank() == 0: + try: + tp_size = ( + meta.gen_allocation.parallel.tp_size + if meta.gen_allocation is not None + else 1 + ) + _mtp_bytes = sum( + t.numel() * t.element_size() + for _, t in mtp_hf_tensors + ) + self.logger.info( + f"[MTPTrain] Sending {len(mtp_hf_tensors)} MTP tensors " + f"({_mtp_bytes / 1024 / 1024:.2f} MB) to EAGLE draft model " + f"via /update_weights_from_tensor " + f"(tp_size={tp_size}, version={meta.version})" + ) + self.rollout_engine.update_weights_from_tensor( + named_tensors=mtp_hf_tensors, + tp_size=tp_size, + flush_cache=True, + ) + self.logger.info( + f"[MTPTrain] Successfully updated EAGLE draft model " + f"MTP weights at version={meta.version}" + ) + except Exception as e: + self.logger.error( + f"[MTPTrain] Failed to update EAGLE draft model " + f"MTP weights via tensor update: {e}. " + f"Draft model spec_accept_rate will degrade!" + ) + elif ( + self.enable_mtp_training + and not getattr(self, "_engine_supports_tensor_update", False) + and not self._mtp_tensor_update_warned + ): + self._mtp_tensor_update_warned = True + self.logger.warning( + "[MTPTrain] Inference engine does not support " + "update_weights_from_tensor. EAGLE draft model MTP weights " + "will NOT be updated, causing spec_accept_rate degradation. " + "Ensure SGLang backend is used with speculative decoding." + ) + dist.barrier(group=self.cpu_group) buffer_size = 0 diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 49b01fdbc9..17dae9a01a 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -7,6 +7,7 @@ from typing import Any import numpy as np +import torch import pybase64 from torchdata.stateful_dataloader import StatefulDataLoader @@ -200,6 +201,97 @@ def build_distributed_weight_update_requests( ] ) + def build_tensor_weight_update_request( + self, + named_tensors: list[tuple[str, torch.Tensor]], + tp_size: int = 1, + flush_cache: bool = True, + ) -> HttpRequest: + """Build HTTP request for /update_weights_from_tensor. + + Used to update EAGLE draft model weights (e.g., MTP layers) + that are not correctly routed by /update_weights_from_distributed. + Borrowed from slime/verl: both use /update_weights_from_tensor + to update draft_worker (EAGLEWorker) which syncs BOTH models. + + In SGLang v0.5.9: + - /update_weights_from_distributed -> tp_worker ONLY + - /update_weights_from_tensor -> draft_worker or tp_worker + EAGLEWorker.update_weights_from_tensor() updates both + self.model_runner (draft/MiMoMTP) and + self.target_worker.model_runner (target/MiMoForCausalLM). + Each model's load_weights() silently skips non-matching names. + + Args: + named_tensors: (name, tensor) pairs in HF format on GPU. + tp_size: Tensor parallel size of inference engine. + flush_cache: Whether to flush KV cache after update. + + Returns: + HttpRequest for /update_weights_from_tensor endpoint. + """ + try: + from sglang.srt.utils import MultiprocessingSerializer + from sglang.srt.utils.patch_torch import ( + monkey_patch_torch_reductions, + ) + except ImportError: + raise ImportError( + "SGLang >= 0.5.9 is required for tensor weight updates. " + "Install sglang to use MTP draft weight sync." + ) + + monkey_patch_torch_reductions() + + # Serialize each tensor via CUDA IPC (shared memory handles). + # Same approach used by verl/slime for colocated weight sync. + serialized_pairs = [ + (name, MultiprocessingSerializer.serialize(tensor.detach())) + for name, tensor in named_tensors + ] + + # Wrap in LocalSerializedTensor format expected by SGLang. + # Each TP rank receives the same full tensors; SGLang's + # model_runner.load_weights() handles TP slicing internally. + try: + from sglang.srt.model_executor.model_runner import ( + LocalSerializedTensor, + ) + except ImportError: + raise ImportError( + "Cannot import LocalSerializedTensor from SGLang. " + "Ensure sglang >= 0.5.9 is installed." + ) + + per_rank_named_tensors = [ + ( + name, + LocalSerializedTensor( + values=[serialized_data] * tp_size + ), + ) + for name, serialized_data in serialized_pairs + ] + + # Serialize the full named_tensors list per TP rank and + # base64-encode for JSON transport over HTTP. + import base64 + + serialized_named_tensors = [ + base64.b64encode( + MultiprocessingSerializer.serialize(per_rank_named_tensors) + ).decode("utf-8") + for _ in range(tp_size) + ] + + return HttpRequest( + endpoint="/update_weights_from_tensor", + payload={ + "serialized_named_tensors": serialized_named_tensors, + "flush_cache": flush_cache, + }, + ) + def build_init_weights_group_request( self, addr: str, server_idx: int, meta: WeightUpdateMeta ) -> HttpRequest: @@ -331,6 +423,17 @@ def update_weights_from_distributed( """Update weights from distributed memory.""" return self._engine.update_weights_from_distributed(meta, param_specs) + def update_weights_from_tensor( + self, + named_tensors: list[tuple[str, torch.Tensor]], + tp_size: int = 1, + flush_cache: bool = True, + ): + """Update EAGLE draft model weights via tensor update path.""" + return self._engine.update_weights_from_tensor( + named_tensors, tp_size, flush_cache + ) + def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: """Update weights from disk.""" return self._engine.update_weights_from_disk(meta) diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index 80f631ef81..a9d1a4b1ec 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -210,6 +210,33 @@ def build_distributed_weight_update_requests( """ ... + def build_tensor_weight_update_request( + self, + named_tensors: list, + tp_size: int = 1, + flush_cache: bool = True, + ) -> "HttpRequest": + """Build HTTP request for /update_weights_from_tensor. + + Used to update EAGLE draft model weights that are not + correctly routed by /update_weights_from_distributed. + + Parameters + ---------- + named_tensors : list + (name, tensor) pairs in HF format + tp_size : int + Tensor parallel size + flush_cache : bool + Whether to flush KV cache after update + + Returns + ------- + HttpRequest + The HTTP request for tensor weight update + """ + ... + def build_init_weights_group_request( self, addr: str, server_idx: int, meta: WeightUpdateMeta ) -> HttpRequest: @@ -974,6 +1001,34 @@ def update_weights_from_distributed( return fut + def update_weights_from_tensor( + self, + named_tensors: list, + tp_size: int = 1, + flush_cache: bool = True, + ): + """Update EAGLE draft model weights via /update_weights_from_tensor. + + Sends MTP layer weights to SGLang server using the tensor update + path, which routes to draft_worker (EAGLEWorker) when speculative + decoding is enabled. EAGLEWorker updates both draft and target models. + + Borrowed from slime/verl approach for MTP draft weight sync. + + Parameters + ---------- + named_tensors : list + (name, tensor) pairs in HF format on GPU + tp_size : int + Tensor parallel size of the inference engine + flush_cache : bool + Whether to flush KV cache after update + """ + http_req = self.backend.build_tensor_weight_update_request( + named_tensors, tp_size, flush_cache + ) + self._run_request_on_all_servers(http_req) + def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: """Update weights in the inference engine from disk. diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 01c8f69219..eea77f6908 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -128,7 +128,7 @@ sglang: speculative_eagle_topk: 1 speculative_num_draft_tokens: 4 speculative_attention_mode: null - enable_draft_weights_cpu_backup: false + enable_draft_weights_cpu_backup: true vllm: From bbc9deb419f13159ba43a7625a8fe115c7d3e44c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 19 Apr 2026 00:17:39 +0800 Subject: [PATCH 045/140] feat(mtp): fix mtp weight update --- areal/engine/megatron_engine.py | 81 +++++++++++++++++++- areal/engine/sglang_remote.py | 39 +++++++++- areal/infra/controller/rollout_callback.py | 41 ++++++++++ areal/infra/controller/rollout_controller.py | 31 ++++++++ areal/infra/remote_inf_engine.py | 26 ++++++- 5 files changed, 208 insertions(+), 10 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 28f26848c4..5734d0e614 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -632,7 +632,7 @@ def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta): f"Connected rollout engine changed from {self.rollout_engine} to {engine}." ) self.rollout_engine = engine - # Check if engine supports tensor weight updates (MTP draft sync) + # Check if engine supports tensor weight updates (MTP draft sync). self._engine_supports_tensor_update = hasattr( engine, "update_weights_from_tensor" ) @@ -2337,6 +2337,72 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta) -> None: self.engine_lock.release() + def _serialize_mtp_tensors_for_update( + self, + mtp_hf_tensors: list[tuple[str, "torch.Tensor"]], + tp_size: int, + ) -> dict: + """Serialize MTP tensors for /update_weights_from_tensor transport. + + Pre-serializes tensor data using SGLang's MultiprocessingSerializer + with CUDA IPC handles, then base64-encodes for JSON/HTTP transport. + This is required for single-controller mode where the engine proxy + (RolloutCallback) communicates via HTTP. + + Args: + mtp_hf_tensors: List of (name, tensor) pairs in HF format. + tp_size: Tensor parallel size of inference engine. + + Returns: + Dict with 'serialized_named_tensors' and 'flush_cache' keys, + ready for /update_weights_from_tensor endpoint. + """ + try: + from sglang.srt.utils import MultiprocessingSerializer + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions + except ImportError: + raise ImportError( + "SGLang >= 0.5.9 is required for tensor weight updates. " + "Install sglang to use MTP draft weight sync." + ) + try: + from sglang.srt.model_executor.model_runner import LocalSerializedTensor + except ImportError: + raise ImportError( + "Cannot import LocalSerializedTensor from SGLang. " + "Ensure sglang >= 0.5.9 is installed." + ) + + monkey_patch_torch_reductions() + + # Inner serialization: each tensor → CUDA IPC handle bytes + serialized_pairs = [ + (name, MultiprocessingSerializer.serialize(tensor.detach())) + for name, tensor in mtp_hf_tensors + ] + + # Wrap in LocalSerializedTensor: one entry per TP rank. + # All TP ranks get the same data; SGLang load_weights() handles slicing. + per_rank_named_tensors = [ + (name, LocalSerializedTensor(values=[data] * tp_size)) + for name, data in serialized_pairs + ] + + # Outer serialization + base64 for JSON transport + import base64 + + serialized_named_tensors = [ + base64.b64encode( + MultiprocessingSerializer.serialize(per_rank_named_tensors) + ).decode("utf-8") + for _ in range(tp_size) + ] + + return { + "serialized_named_tensors": serialized_named_tensors, + "flush_cache": True, + } + @trace_perf("megatron_engine.update_weights_from_distributed", category="comm") def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: DeviceRuntimeInfo.get_current().log("_update_weights_from_distributed start") @@ -2450,9 +2516,15 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"via /update_weights_from_tensor " f"(tp_size={tp_size}, version={meta.version})" ) + # Serialize tensors on the training side. This is required + # for single-controller mode where the engine is a + # RolloutCallback proxy — tensor data must be serialized + # before it can travel through the HTTP callback chain. + serialized_payload = self._serialize_mtp_tensors_for_update( + mtp_hf_tensors, tp_size + ) self.rollout_engine.update_weights_from_tensor( - named_tensors=mtp_hf_tensors, - tp_size=tp_size, + serialized_payload=serialized_payload, flush_cache=True, ) self.logger.info( @@ -2463,7 +2535,8 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: self.logger.error( f"[MTPTrain] Failed to update EAGLE draft model " f"MTP weights via tensor update: {e}. " - f"Draft model spec_accept_rate will degrade!" + f"Draft model spec_accept_rate will degrade!", + exc_info=True, ) elif ( self.enable_mtp_training diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 17dae9a01a..b64e05efda 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -425,13 +425,44 @@ def update_weights_from_distributed( def update_weights_from_tensor( self, - named_tensors: list[tuple[str, torch.Tensor]], + named_tensors: list[tuple[str, torch.Tensor]] | None = None, tp_size: int = 1, flush_cache: bool = True, + serialized_payload: dict | None = None, + ): + """Update EAGLE draft model weights via tensor update path. + + Supports two modes: + 1. Raw tensors: pass named_tensors + tp_size (original path) + 2. Pre-serialized: pass serialized_payload dict (callback chain path) + + In single-controller mode, MegatronEngine pre-serializes the tensors + and passes serialized_payload to avoid GPU tensor serialization issues + when crossing process boundaries via RolloutCallback. + """ + if serialized_payload is not None: + # Pre-serialized path: send directly to SGLang server + return self._engine.update_weights_from_tensor_serialized( + serialized_payload + ) + else: + # Raw tensor path: build request from tensors + return self._engine.update_weights_from_tensor( + named_tensors, tp_size, flush_cache + ) + + def update_weights_from_tensor_serialized( + self, + serialized_payload: dict, ): - """Update EAGLE draft model weights via tensor update path.""" - return self._engine.update_weights_from_tensor( - named_tensors, tp_size, flush_cache + """Update EAGLE draft model weights with pre-serialized tensor data. + + Used in single-controller mode where the RolloutController + delegates to this method after receiving serialized data from + the training side via the callback chain. + """ + return self._engine.update_weights_from_tensor_serialized( + serialized_payload ) def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: diff --git a/areal/infra/controller/rollout_callback.py b/areal/infra/controller/rollout_callback.py index 7b6852348f..cada269315 100644 --- a/areal/infra/controller/rollout_callback.py +++ b/areal/infra/controller/rollout_callback.py @@ -178,3 +178,44 @@ def continue_generation(self) -> None: This is synchronous as it should complete before returning control. """ self._post("/callback/continue_generation") + + def update_weights_from_tensor( + self, + named_tensors: list | None = None, + tp_size: int = 1, + flush_cache: bool = True, + serialized_payload: dict | None = None, + ) -> None: + """Callback to controller to update EAGLE draft model MTP weights. + + In single-controller mode, tensor data is pre-serialized by the + MegatronEngine before calling this method. The serialized payload + (base64-encoded CUDA IPC handles) travels through the HTTP callback + chain to the RolloutController, which delegates to the inference + engine workers. + + This is synchronous (blocking) because the MTP tensor update + happens AFTER the main NCCL weight sync (within the pause window), + so there is no deadlock risk from blocking. + + Parameters + ---------- + named_tensors : list, optional + Ignored in callback mode — tensors must be pre-serialized. + tp_size : int + Ignored in callback mode — encoded in serialized_payload. + flush_cache : bool + Whether to flush KV cache after update. + serialized_payload : dict, optional + Pre-serialized payload for /update_weights_from_tensor endpoint. + """ + if serialized_payload is None: + raise ValueError( + "RolloutCallback.update_weights_from_tensor requires " + "serialized_payload (pre-serialized tensor data). " + "Raw tensor mode is not supported through the callback chain." + ) + payload = { + "serialized_payload": serialize_value(serialized_payload), + } + self._post("/callback/update_weights_tensor", payload) diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 8fde8e5fdf..a6f6df4b26 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -582,6 +582,17 @@ def continue_generation(): self._callback_loop.run_until_complete(self.continue_generation()) return jsonify({"status": "ok"}) + @app.route("/callback/update_weights_tensor", methods=["POST"]) + def update_weights_tensor(): + payload = request.get_json() or {} + serialized_payload = deserialize_value( + payload.get("serialized_payload") + ) + self._callback_loop.run_until_complete( + self.update_weights_from_tensor(serialized_payload) + ) + return jsonify({"status": "ok"}) + @app.route("/callback/rollout_complete", methods=["POST"]) def rollout_complete(): payload = request.get_json() or {} @@ -1037,6 +1048,26 @@ async def pause_generation(self): async def continue_generation(self): await self._collective_rpc_async("continue_generation") + async def update_weights_from_tensor( + self, serialized_payload: dict + ) -> None: + """Update EAGLE draft model MTP weights via tensor update path. + + Receives pre-serialized tensor data from the training side and + delegates to inference engine workers which send the serialized + payload directly to the SGLang server's /update_weights_from_tensor + endpoint. + + Parameters + ---------- + serialized_payload : dict + Pre-serialized payload for /update_weights_from_tensor. + """ + await self._collective_rpc_async( + "update_weights_from_tensor_serialized", + serialized_payload=serialized_payload, + ) + def set_version(self, version: int) -> None: with self._version_lock: self._version = version diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index a9d1a4b1ec..e332a5917c 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -1013,8 +1013,6 @@ def update_weights_from_tensor( path, which routes to draft_worker (EAGLEWorker) when speculative decoding is enabled. EAGLEWorker updates both draft and target models. - Borrowed from slime/verl approach for MTP draft weight sync. - Parameters ---------- named_tensors : list @@ -1029,6 +1027,30 @@ def update_weights_from_tensor( ) self._run_request_on_all_servers(http_req) + def update_weights_from_tensor_serialized( + self, + serialized_payload: dict, + ): + """Update EAGLE draft model weights with pre-serialized tensor data. + + Accepts a pre-serialized payload dict (with 'serialized_named_tensors' + and 'flush_cache' keys) and sends it directly to the SGLang server's + /update_weights_from_tensor endpoint. Used in single-controller mode + where tensor serialization happens on the training side. + + Parameters + ---------- + serialized_payload : dict + Pre-serialized payload for /update_weights_from_tensor endpoint. + Must contain 'serialized_named_tensors' (list of base64 strings) + and optionally 'flush_cache' (bool). + """ + http_req = HttpRequest( + endpoint="/update_weights_from_tensor", + payload=serialized_payload, + ) + self._run_request_on_all_servers(http_req) + def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: """Update weights in the inference engine from disk. From ee24e8d17a12e9cb9d2afa57d623502edc2700ff Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 19 Apr 2026 13:08:49 +0800 Subject: [PATCH 046/140] fix(controller): fix callback --- areal/engine/megatron_engine.py | 19 ++++++- areal/infra/controller/rollout_callback.py | 41 +++++++++++++- areal/infra/controller/rollout_controller.py | 58 +++++++++++++++----- 3 files changed, 101 insertions(+), 17 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 5734d0e614..bc0e5e0b42 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2114,7 +2114,24 @@ def _update_bucket_weights_from_distributed( for handle in handles: handle.wait() - fut.result() + # The callback server now returns HTTP 200 immediately (fire-and-forget) + # before the NCCL transfer completes on the inference side. Since NCCL + # broadcast is collective, handle.wait() above already guarantees BOTH + # sides have completed the data transfer. fut.result() only confirms + # the HTTP POST was accepted. Use a short timeout to catch delivery + # errors without blocking on infrastructure proxy timeouts (504). + try: + fut.result(timeout=30) + except TimeoutError: + self.logger.warning( + "Callback response timed out, but NCCL broadcast " + "completed successfully. Continuing weight update." + ) + except Exception as e: + self.logger.warning( + f"Callback response error: {e}. NCCL broadcast " + "completed successfully. Continuing weight update." + ) converted_named_tensors.clear() diff --git a/areal/infra/controller/rollout_callback.py b/areal/infra/controller/rollout_callback.py index cada269315..cd18d8e71a 100644 --- a/areal/infra/controller/rollout_callback.py +++ b/areal/infra/controller/rollout_callback.py @@ -1,5 +1,5 @@ from concurrent.futures import Future -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any import requests @@ -29,6 +29,26 @@ class RolloutCallback: controller_addr: str request_timeout: float = 600.0 + # Bypass HTTP_PROXY / HTTPS_PROXY for all internal callback requests. + # In environments where a corporate HTTP proxy is configured (e.g., + # HTTP_PROXY=http://sys-proxy-rd-relay.byted.org:8118), Python's + # requests library auto-routes through the proxy. When the callback + # server's address (e.g., an IPv6 pod IP) is not listed in NO_PROXY, + # the proxy intercepts the request and applies its own timeout (~60s), + # causing 504 Gateway Timeout on long-running NCCL weight updates + # before the operation completes. Using a Session with trust_env=False + # ensures direct connection to the callback server. + _no_proxy_session: requests.Session | None = field(default=None, init=False, repr=False) + + @property + def _session(self) -> requests.Session: + """Lazily create a requests.Session that bypasses env proxy settings.""" + if self._no_proxy_session is None: + s = requests.Session() + s.trust_env = False # Ignore HTTP_PROXY / HTTPS_PROXY / NO_PROXY + object.__setattr__(self, "_no_proxy_session", s) + return self._no_proxy_session + def _post(self, endpoint: str, payload: dict[str, Any] | None = None) -> dict: """Make synchronous HTTP POST to controller callback endpoint. @@ -46,7 +66,7 @@ def _post(self, endpoint: str, payload: dict[str, Any] | None = None) -> dict: """ url = f"http://{self.controller_addr}{endpoint}" try: - resp = requests.post( + resp = self._session.post( url, json=payload or {}, timeout=self.request_timeout, @@ -218,4 +238,19 @@ def update_weights_from_tensor( payload = { "serialized_payload": serialize_value(serialized_payload), } - self._post("/callback/update_weights_tensor", payload) + # Use non-blocking POST as defense-in-depth. Even with proxy bypass, + # large MTP tensor payloads may take time to transmit. The fire-and- + # forget callback pattern ensures the HTTP layer does not block. + fut = self._post_nowait_void("/callback/update_weights_tensor", payload) + try: + fut.result(timeout=120) + except TimeoutError: + logger.warning( + "update_weights_from_tensor callback timed out. " + "Tensor update dispatched via fire-and-forget." + ) + except Exception as e: + logger.warning( + f"update_weights_from_tensor callback error: {e}. " + "Tensor update dispatched via fire-and-forget." + ) diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index a6f6df4b26..5a617b1e95 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -552,7 +552,9 @@ def _start_callback_server(self): def init_weights_group(): payload = request.get_json() or {} meta = deserialize_value(payload.get("meta")) - self._callback_loop.run_until_complete(self.init_weights_update_group(meta)) + asyncio.run_coroutine_threadsafe( + self.init_weights_update_group(meta), self._callback_loop + ).result() return jsonify({"status": "ok"}) @app.route("/callback/update_weights_xccl", methods=["POST"]) @@ -560,8 +562,15 @@ def update_weights(): payload = request.get_json() or {} meta = deserialize_value(payload.get("meta")) param_specs = deserialize_value(payload.get("param_specs")) - self._callback_loop.run_until_complete( - self.update_weights_from_distributed(meta, param_specs) + # Fire-and-forget: schedule the NCCL weight update as a background + # task and return HTTP 200 immediately. This prevents infrastructure + # proxy timeouts (504) since the full NCCL transfer chain can take + # >60s. NCCL broadcast is collective — when the training side's + # broadcast handle completes, the receive side has the data. + # Inspired by verl's pattern of decoupling HTTP from NCCL ops. + asyncio.run_coroutine_threadsafe( + self.update_weights_from_distributed(meta, param_specs), + self._callback_loop, ) return jsonify({"status": "ok"}) @@ -569,17 +578,23 @@ def update_weights(): def update_weights_disk(): payload = request.get_json() or {} meta = deserialize_value(payload.get("meta")) - self._callback_loop.run_until_complete(self.update_weights_from_disk(meta)) + asyncio.run_coroutine_threadsafe( + self.update_weights_from_disk(meta), self._callback_loop + ).result() return jsonify({"status": "ok"}) @app.route("/callback/pause_generation", methods=["POST"]) def pause_generation(): - self._callback_loop.run_until_complete(self.pause_generation()) + asyncio.run_coroutine_threadsafe( + self.pause_generation(), self._callback_loop + ).result() return jsonify({"status": "ok"}) @app.route("/callback/continue_generation", methods=["POST"]) def continue_generation(): - self._callback_loop.run_until_complete(self.continue_generation()) + asyncio.run_coroutine_threadsafe( + self.continue_generation(), self._callback_loop + ).result() return jsonify({"status": "ok"}) @app.route("/callback/update_weights_tensor", methods=["POST"]) @@ -588,8 +603,10 @@ def update_weights_tensor(): serialized_payload = deserialize_value( payload.get("serialized_payload") ) - self._callback_loop.run_until_complete( - self.update_weights_from_tensor(serialized_payload) + # Fire-and-forget: same pattern as update_weights_xccl. + asyncio.run_coroutine_threadsafe( + self.update_weights_from_tensor(serialized_payload), + self._callback_loop, ) return jsonify({"status": "ok"}) @@ -627,12 +644,28 @@ def handle_error(e): werkzeug_logger = stdlib_logging.getLogger("werkzeug") werkzeug_logger.setLevel(stdlib_logging.WARNING) - def serve_forever(): - # Create and set event loop for this thread + def run_async_loop(): + """Run a dedicated asyncio event loop in a background thread. + + This loop processes coroutines scheduled via + asyncio.run_coroutine_threadsafe(). Unlike the original design + which used run_until_complete() from the werkzeug handler thread, + a dedicated running loop supports both blocking (.result()) and + fire-and-forget patterns — critical for avoiding proxy/infra + timeouts on long-running NCCL weight transfers. + """ self._callback_loop = asyncio.new_event_loop() asyncio.set_event_loop(self._callback_loop) - # Signal that the loop is ready self._callback_loop_ready.set() + self._callback_loop.run_forever() + + self._callback_loop_thread = threading.Thread( + target=run_async_loop, daemon=True + ) + self._callback_loop_thread.start() + self._callback_loop_ready.wait() + + def serve_forever(): logger.info( f"Callback server started on {format_hostport(self._callback_host, self._callback_port)}" ) @@ -642,8 +675,6 @@ def serve_forever(): target=serve_forever, daemon=True ) self._callback_server_thread.start() - # Wait for loop to be created - self._callback_loop_ready.wait() def _stop_callback_server(self): """Stop the callback server if running.""" @@ -651,6 +682,7 @@ def _stop_callback_server(self): logger.info("Stopping callback server...") self._callback_server.shutdown() if self._callback_loop is not None: + self._callback_loop.call_soon_threadsafe(self._callback_loop.stop) self._callback_loop.close() self._callback_server = None self._callback_app = None From e03120647ab48d6ba3ffbcb5bf9eea227cb99df0 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 19 Apr 2026 14:50:50 +0800 Subject: [PATCH 047/140] fix(controller): skip _NO_PROXY --- areal/infra/controller/rollout_callback.py | 45 +++++++++++----------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/areal/infra/controller/rollout_callback.py b/areal/infra/controller/rollout_callback.py index cd18d8e71a..2bc56f54ad 100644 --- a/areal/infra/controller/rollout_callback.py +++ b/areal/infra/controller/rollout_callback.py @@ -1,5 +1,5 @@ from concurrent.futures import Future -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any import requests @@ -11,6 +11,11 @@ logger = logging.getLogger(__name__) +# Direct-connect proxy setting: tells requests to bypass all env proxies. +# This is used for every callback HTTP request to avoid corporate proxy +# interference (504 Gateway Timeout) on internal pod-to-pod communication. +_NO_PROXY = {"http": None, "https": None} + @dataclass class RolloutCallback: @@ -29,29 +34,24 @@ class RolloutCallback: controller_addr: str request_timeout: float = 600.0 - # Bypass HTTP_PROXY / HTTPS_PROXY for all internal callback requests. - # In environments where a corporate HTTP proxy is configured (e.g., - # HTTP_PROXY=http://sys-proxy-rd-relay.byted.org:8118), Python's - # requests library auto-routes through the proxy. When the callback - # server's address (e.g., an IPv6 pod IP) is not listed in NO_PROXY, - # the proxy intercepts the request and applies its own timeout (~60s), - # causing 504 Gateway Timeout on long-running NCCL weight updates - # before the operation completes. Using a Session with trust_env=False - # ensures direct connection to the callback server. - _no_proxy_session: requests.Session | None = field(default=None, init=False, repr=False) - - @property - def _session(self) -> requests.Session: - """Lazily create a requests.Session that bypasses env proxy settings.""" - if self._no_proxy_session is None: - s = requests.Session() - s.trust_env = False # Ignore HTTP_PROXY / HTTPS_PROXY / NO_PROXY - object.__setattr__(self, "_no_proxy_session", s) - return self._no_proxy_session - def _post(self, endpoint: str, payload: dict[str, Any] | None = None) -> dict: """Make synchronous HTTP POST to controller callback endpoint. + Uses ``proxies=_NO_PROXY`` to bypass environment proxy variables + (HTTP_PROXY / HTTPS_PROXY / NO_PROXY). In environments where a + corporate HTTP proxy is configured (e.g., + HTTP_PROXY=http://sys-proxy-rd-relay.byted.org:8118), Python's + ``requests`` library auto-routes through the proxy. When the + callback server address (e.g., an IPv6 pod IP) is not listed in + NO_PROXY, the proxy intercepts the request and applies its own + timeout (~60 s), causing 504 Gateway Timeout on long-running NCCL + weight updates. Passing ``proxies={"http": None, "https": None}`` + ensures a direct connection to the callback server on every call, + with zero extra state — which is critical because this dataclass is + serialized across RPC boundaries by AReaL's ``serialize_value`` / + ``deserialize_value`` (adding non-init fields would break + deserialization). + Parameters ---------- endpoint : str @@ -66,10 +66,11 @@ def _post(self, endpoint: str, payload: dict[str, Any] | None = None) -> dict: """ url = f"http://{self.controller_addr}{endpoint}" try: - resp = self._session.post( + resp = requests.post( url, json=payload or {}, timeout=self.request_timeout, + proxies=_NO_PROXY, ) resp.raise_for_status() return resp.json() From aaa3aa545a3325b22d717e78f9e21fafed8a2cf5 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 19 Apr 2026 22:08:00 +0800 Subject: [PATCH 048/140] fix(controller): fix update --- areal/infra/controller/rollout_callback.py | 22 ++++++-------------- areal/infra/controller/rollout_controller.py | 3 +-- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/areal/infra/controller/rollout_callback.py b/areal/infra/controller/rollout_callback.py index 2bc56f54ad..c5ce826823 100644 --- a/areal/infra/controller/rollout_callback.py +++ b/areal/infra/controller/rollout_callback.py @@ -239,19 +239,9 @@ def update_weights_from_tensor( payload = { "serialized_payload": serialize_value(serialized_payload), } - # Use non-blocking POST as defense-in-depth. Even with proxy bypass, - # large MTP tensor payloads may take time to transmit. The fire-and- - # forget callback pattern ensures the HTTP layer does not block. - fut = self._post_nowait_void("/callback/update_weights_tensor", payload) - try: - fut.result(timeout=120) - except TimeoutError: - logger.warning( - "update_weights_from_tensor callback timed out. " - "Tensor update dispatched via fire-and-forget." - ) - except Exception as e: - logger.warning( - f"update_weights_from_tensor callback error: {e}. " - "Tensor update dispatched via fire-and-forget." - ) + # Synchronous (blocking) POST: the callback server handler now + # blocks until SGLang actually completes the tensor update, so the + # HTTP response guarantees the update is done. This ensures training + # does not proceed to continue_generation before the tensor update + # finishes, preventing worker RPC server contention and hangs. + self._post("/callback/update_weights_tensor", payload) diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 5a617b1e95..a11a67e78e 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -603,11 +603,10 @@ def update_weights_tensor(): serialized_payload = deserialize_value( payload.get("serialized_payload") ) - # Fire-and-forget: same pattern as update_weights_xccl. asyncio.run_coroutine_threadsafe( self.update_weights_from_tensor(serialized_payload), self._callback_loop, - ) + ).result() return jsonify({"status": "ok"}) @app.route("/callback/rollout_complete", methods=["POST"]) From 4d04c351ef7b100a09175271eb97c97ce7845bb2 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 00:32:25 +0800 Subject: [PATCH 049/140] feat(controller): add log --- areal/engine/megatron_engine.py | 195 ++++++++++--------- areal/infra/controller/rollout_callback.py | 57 +++++- areal/infra/controller/rollout_controller.py | 152 +++++++++++++-- areal/infra/remote_inf_engine.py | 70 ++++++- 4 files changed, 359 insertions(+), 115 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index bc0e5e0b42..51a8724e8e 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -480,7 +480,10 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): # Non-last stages legitimately have 0 MTP params. is_last_stage = True try: - if mpu.is_initialized() and mpu.get_pipeline_model_parallel_world_size() > 1: + if ( + mpu.is_initialized() + and mpu.get_pipeline_model_parallel_world_size() > 1 + ): is_last_stage = mpu.is_pipeline_last_stage() except Exception: pass @@ -878,8 +881,8 @@ def _collect_mtp_loss(self) -> dict[str, float]: f"emb={emb_g**0.5:.6f}, lmh={lmh_g**0.5:.6f}, " f"total={total_params}, no_grad={no_grad_params}" ) - mtp_stats["mtp_grad_norm"] = mtp_g ** 0.5 - mtp_stats["non_mtp_grad_norm"] = non_mtp_g ** 0.5 + mtp_stats["mtp_grad_norm"] = mtp_g**0.5 + mtp_stats["non_mtp_grad_norm"] = non_mtp_g**0.5 except Exception as e: self.logger.warning( f"[MTPDetach] Grad norm logging failed: {e}" @@ -1004,15 +1007,11 @@ def forward_step(batch_iter, model): # labels is None. Disable MTP at runtime in this case. if not self.enable_mtp_training and forward_only: _unwrapped_def = model - while hasattr(_unwrapped_def, 'module'): + while hasattr(_unwrapped_def, "module"): _unwrapped_def = _unwrapped_def.module - _def_mtp = getattr(_unwrapped_def, 'mtp', None) - _def_mtp_process = getattr( - _unwrapped_def, 'mtp_process', False - ) - _def_mtp_layers = getattr( - _unwrapped_def.config, 'mtp_num_layers', None - ) + _def_mtp = getattr(_unwrapped_def, "mtp", None) + _def_mtp_process = getattr(_unwrapped_def, "mtp_process", False) + _def_mtp_layers = getattr(_unwrapped_def.config, "mtp_num_layers", None) if ( _def_mtp is not None or _def_mtp_process @@ -1037,17 +1036,15 @@ def forward_step(batch_iter, model): if self.enable_mtp_training: _unwrapped = model - while hasattr(_unwrapped, 'module'): + while hasattr(_unwrapped, "module"): _unwrapped = _unwrapped.module if forward_only: # -- Inference: disable MTP to avoid crash -- - _saved_mtp = getattr(_unwrapped, 'mtp', None) - _saved_mtp_process = getattr( - _unwrapped, 'mtp_process', None - ) + _saved_mtp = getattr(_unwrapped, "mtp", None) + _saved_mtp_process = getattr(_unwrapped, "mtp_process", None) _saved_mtp_layers = getattr( - _unwrapped.config, 'mtp_num_layers', None + _unwrapped.config, "mtp_num_layers", None ) if ( _saved_mtp is not None @@ -1078,9 +1075,7 @@ def forward_step(batch_iter, model): _mtp_labels = _input_ids # loss_mask carried through pack/pad pipeline; # fall back to None → megatron uses ones_like. - _mtp_loss_mask = mb_input.padded_mb.get( - "loss_mask", None - ) + _mtp_loss_mask = mb_input.padded_mb.get("loss_mask", None) extra_block_kwargs = {"labels": _mtp_labels} if _mtp_loss_mask is not None: extra_block_kwargs["loss_mask"] = _mtp_loss_mask @@ -1220,18 +1215,15 @@ def _patched_postprocess( ): # Path 2: functional_call with detached # output_layer params for MTP logits - _mtp_hs = hidden_states_list[ - mtp_layer_number + 1 - ] + _mtp_hs = hidden_states_list[mtp_layer_number + 1] _ol = self_model.output_layer _ol_params = { - k: v.detach() - for k, v in _ol.named_parameters() + k: v.detach() for k, v in _ol.named_parameters() } _ol_buffers = dict(_ol.named_buffers()) _ol_kwargs = { - 'weight': _mtp_output_weight, - 'runtime_gather_output': ( + "weight": _mtp_output_weight, + "runtime_gather_output": ( runtime_gather_output ), } @@ -1256,10 +1248,8 @@ def _patched_postprocess( cp_group=self_model.cp_group, packed_seq_params=packed_seq_params, ) - mtp_loss = ( - self_model.compute_language_model_loss( - mtp_labels, mtp_logits - ) + mtp_loss = self_model.compute_language_model_loss( + mtp_labels, mtp_logits ) mtp_loss = loss_mask * mtp_loss if self_model.training: @@ -1287,9 +1277,7 @@ def _patched_postprocess( else: hidden_states = MTPLossAutoScaler.apply( hidden_states, - mtp_loss_scale - * mtp_loss - / num_tokens, + mtp_loss_scale * mtp_loss / num_tokens, ) # Path 1: apply gradient isolator @@ -1310,28 +1298,24 @@ def _patched_postprocess( if inference_context.is_static_batching(): hidden_states = hidden_states[-1:, :, :] else: - if ( - self_model.output_layer.sequence_parallel - ): + if self_model.output_layer.sequence_parallel: from megatron.core.tensor_parallel import ( gather_from_sequence_parallel_region, ) - hidden_states = gather_from_sequence_parallel_region( - hidden_states, - group=self_model.pg_collection.tp, + hidden_states = ( + gather_from_sequence_parallel_region( + hidden_states, + group=self_model.pg_collection.tp, + ) ) self_model.output_layer.sequence_parallel = ( False ) sequence_parallel_override = True - hidden_states = ( - inference_context.last_token_logits( - hidden_states.squeeze(1).unsqueeze( - 0 - ) - ).unsqueeze(1) - ) + hidden_states = inference_context.last_token_logits( + hidden_states.squeeze(1).unsqueeze(0) + ).unsqueeze(1) # Main logits: ORIGINAL output_weight (GRPO grad flows) logits, _ = self_model.output_layer( @@ -1346,9 +1330,7 @@ def _patched_postprocess( and inference_context.is_dynamic_batching() and inference_context.materialize_only_last_token_logits ) - self_model.output_layer.sequence_parallel = ( - True - ) + self_model.output_layer.sequence_parallel = True if labels is None: return logits.transpose(0, 1).contiguous() @@ -1369,11 +1351,8 @@ def _patched_postprocess( ) # Path 3: patch _get_embeddings for embedding detach - _mtp_block = getattr(_unwrapped, 'mtp', None) - if ( - _mtp_block is not None - and hasattr(_mtp_block, 'layers') - ): + _mtp_block = getattr(_unwrapped, "mtp", None) + if _mtp_block is not None and hasattr(_mtp_block, "layers"): for _layer in _mtp_block.layers: _orig_get_emb = _layer._get_embeddings @@ -1401,6 +1380,7 @@ def _patched_get_embeddings( from megatron.core.utils import ( make_viewless_tensor, ) + _hs = make_viewless_tensor( inp=_hs.detach(), requires_grad=True, @@ -1408,12 +1388,8 @@ def _patched_get_embeddings( ) return _ids, _pos, _dec_input, _hs - _layer._get_embeddings = ( - _patched_get_embeddings - ) - _mtp_get_emb_restore.append( - (_layer, _orig_get_emb) - ) + _layer._get_embeddings = _patched_get_embeddings + _mtp_get_emb_restore.append((_layer, _orig_get_emb)) self.logger.debug( f"[MTPDetach] Patched _get_embeddings on " @@ -1446,8 +1422,10 @@ def _patched_get_embeddings( _orig_clm = _unwrapped.compute_language_model_loss def _mtp_loss_fn( - _labels, _logits, - _rem=_remaining, _orig=_orig_clm, + _labels, + _logits, + _rem=_remaining, + _orig=_orig_clm, ): if _rem[0] > 0: _rem[0] -= 1 @@ -1472,17 +1450,18 @@ def _mtp_loss_fn( # We apply the same pattern here by monkey-patching each # MTP layer's _checkpointed_forward during training. # ----------------------------------------------------------- - _mtp_block = getattr(_unwrapped, 'mtp', None) + _mtp_block = getattr(_unwrapped, "mtp", None) if ( _mtp_block is not None - and hasattr(_mtp_block, 'layers') - and _unwrapped.config.recompute_granularity == 'full' + and hasattr(_mtp_block, "layers") + and _unwrapped.config.recompute_granularity == "full" ): for _layer in _mtp_block.layers: _orig_ckpt_fwd = _layer._checkpointed_forward def _patched_checkpointed_forward( - forward_func, *args, + forward_func, + *args, _layer_ref=_layer, **kwargs, ): @@ -1518,20 +1497,14 @@ def _ckpt_wrapper(*flat_args): _orig_args = flat_args[:n_orig] _tk_vals = flat_args[n_orig:] _rebuilt_kw = { - k: v for k, v in zip( - _tk_keys, _tk_vals - ) + k: v for k, v in zip(_tk_keys, _tk_vals) } _rebuilt_kw.update(_non_tensor_kw) - return forward_func( - *_orig_args, **_rebuilt_kw - ) + return forward_func(*_orig_args, **_rebuilt_kw) _cfg = _layer_ref.config - if _cfg.recompute_method == 'uniform': - assert ( - _cfg.recompute_num_layers == 1 - ), ( + if _cfg.recompute_method == "uniform": + assert _cfg.recompute_num_layers == 1, ( "recompute_num_layers must be 1 " "for MTP recompute" ) @@ -1539,6 +1512,7 @@ def _ckpt_wrapper(*flat_args): from megatron.core.extensions.transformer_engine import ( te_checkpoint, ) + return te_checkpoint( _ckpt_wrapper, _cfg.distribute_saved_activations, @@ -1554,8 +1528,9 @@ def _ckpt_wrapper(*flat_args): *args, *_tensor_kw.values(), ) - elif _cfg.recompute_method == 'block': + elif _cfg.recompute_method == "block": import warnings + warnings.warn( "recompute_method == 'block' is not " "supported for MTP yet. " @@ -1567,12 +1542,8 @@ def _ckpt_wrapper(*flat_args): "Invalid activation recompute method." ) - _layer._checkpointed_forward = ( - _patched_checkpointed_forward - ) - _mtp_ckpt_restore.append( - (_layer, _orig_ckpt_fwd) - ) + _layer._checkpointed_forward = _patched_checkpointed_forward + _mtp_ckpt_restore.append((_layer, _orig_ckpt_fwd)) self.logger.debug( f"[MTPTrain] Patched _checkpointed_forward on " @@ -1591,12 +1562,14 @@ def _ckpt_wrapper(*flat_args): try: output = packed_context_parallel_forward( - model, mb_input.padded_mb, + model, + mb_input.padded_mb, extra_block_kwargs=extra_block_kwargs, ) finally: if _postprocess_restore is not None: import types as _types_mod + _uw, _orig_pp = _postprocess_restore _uw._postprocess = _types_mod.MethodType(_orig_pp, _uw) for _layer, _orig_get_emb in _mtp_get_emb_restore: @@ -1606,9 +1579,7 @@ def _ckpt_wrapper(*flat_args): _uw.mtp = _sm _uw.mtp_process = _sp _uw.config.mtp_num_layers = _sl - self.logger.debug( - "[MTPTrain] Restored MTP after inference forward" - ) + self.logger.debug("[MTPTrain] Restored MTP after inference forward") if _clm_loss_restore is not None: _uw, _orig = _clm_loss_restore _uw.compute_language_model_loss = _orig @@ -2356,7 +2327,7 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta) -> None: def _serialize_mtp_tensors_for_update( self, - mtp_hf_tensors: list[tuple[str, "torch.Tensor"]], + mtp_hf_tensors: list[tuple[str, torch.Tensor]], tp_size: int, ) -> dict: """Serialize MTP tensors for /update_weights_from_tensor transport. @@ -2524,9 +2495,10 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: else 1 ) _mtp_bytes = sum( - t.numel() * t.element_size() - for _, t in mtp_hf_tensors + t.numel() * t.element_size() for _, t in mtp_hf_tensors ) + import time as _time + self.logger.info( f"[MTPTrain] Sending {len(mtp_hf_tensors)} MTP tensors " f"({_mtp_bytes / 1024 / 1024:.2f} MB) to EAGLE draft model " @@ -2537,16 +2509,55 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: # for single-controller mode where the engine is a # RolloutCallback proxy — tensor data must be serialized # before it can travel through the HTTP callback chain. + _t_ser0 = _time.time() + self.logger.info( + f"[MTPTrain][Diag] Starting _serialize_mtp_tensors_for_update " + f"(n_tensors={len(mtp_hf_tensors)}, tp_size={tp_size})..." + ) serialized_payload = self._serialize_mtp_tensors_for_update( mtp_hf_tensors, tp_size ) + _t_ser1 = _time.time() + # Log serialized payload info + _sp_keys = ( + list(serialized_payload.keys()) + if isinstance(serialized_payload, dict) + else "N/A" + ) + _n_snt = ( + len(serialized_payload.get("serialized_named_tensors", [])) + if isinstance(serialized_payload, dict) + else 0 + ) + _snt_sizes = ( + [ + len(s) + for s in serialized_payload.get("serialized_named_tensors", []) + ] + if isinstance(serialized_payload, dict) + else [] + ) + self.logger.info( + f"[MTPTrain][Diag] Serialization completed in {_t_ser1 - _t_ser0:.3f}s. " + f"payload_keys={_sp_keys}, n_serialized_tensors={_n_snt}, " + f"serialized_tensor_sizes_bytes={_snt_sizes}, " + f"rollout_engine_type={type(self.rollout_engine).__name__}" + ) + _t_call0 = _time.time() + self.logger.info( + "[MTPTrain][Diag] Calling rollout_engine.update_weights_from_tensor()..." + ) self.rollout_engine.update_weights_from_tensor( serialized_payload=serialized_payload, flush_cache=True, ) + _t_call1 = _time.time() self.logger.info( f"[MTPTrain] Successfully updated EAGLE draft model " - f"MTP weights at version={meta.version}" + f"MTP weights at version={meta.version} " + f"(serialize={_t_ser1 - _t_ser0:.3f}s, " + f"update_call={_t_call1 - _t_call0:.3f}s, " + f"total={_t_call1 - _t_ser0:.3f}s)" ) except Exception as e: self.logger.error( diff --git a/areal/infra/controller/rollout_callback.py b/areal/infra/controller/rollout_callback.py index c5ce826823..030dfffad3 100644 --- a/areal/infra/controller/rollout_callback.py +++ b/areal/infra/controller/rollout_callback.py @@ -1,3 +1,5 @@ +import sys +import time from concurrent.futures import Future from dataclasses import dataclass from typing import Any @@ -65,6 +67,11 @@ def _post(self, endpoint: str, payload: dict[str, Any] | None = None) -> dict: Response JSON from controller """ url = f"http://{self.controller_addr}{endpoint}" + _t_post = time.time() + logger.info( + f"[DiagMTP][Callback] _post: sending POST to {url} " + f"(timeout={self.request_timeout}s, proxies={_NO_PROXY})" + ) try: resp = requests.post( url, @@ -72,10 +79,18 @@ def _post(self, endpoint: str, payload: dict[str, Any] | None = None) -> dict: timeout=self.request_timeout, proxies=_NO_PROXY, ) + _elapsed = time.time() - _t_post + logger.info( + f"[DiagMTP][Callback] _post: response received from {url} " + f"in {_elapsed:.3f}s (status={resp.status_code})" + ) resp.raise_for_status() return resp.json() except requests.RequestException as e: - logger.error(f"Callback to {url} failed: {e}") + _elapsed = time.time() - _t_post + logger.error( + f"[DiagMTP][Callback] _post: FAILED {url} after {_elapsed:.3f}s: {e}" + ) raise def _post_nowait( @@ -236,12 +251,40 @@ def update_weights_from_tensor( "serialized_payload (pre-serialized tensor data). " "Raw tensor mode is not supported through the callback chain." ) + _t0 = time.time() + logger.info( + f"[DiagMTP][Callback] update_weights_from_tensor ENTERED. " + f"serialized_payload keys={list(serialized_payload.keys())}, " + f"controller_addr={self.controller_addr}" + ) payload = { "serialized_payload": serialize_value(serialized_payload), } - # Synchronous (blocking) POST: the callback server handler now - # blocks until SGLang actually completes the tensor update, so the - # HTTP response guarantees the update is done. This ensures training - # does not proceed to continue_generation before the tensor update - # finishes, preventing worker RPC server contention and hangs. - self._post("/callback/update_weights_tensor", payload) + _payload_size = sys.getsizeof(str(payload)) + _t1 = time.time() + logger.info( + f"[DiagMTP][Callback] serialize_value took {_t1 - _t0:.3f}s, " + f"payload_approx_size={_payload_size} bytes" + ) + + # Synchronous blocking POST: wait for callback server to complete + # the tensor update before returning. This ensures training does + # not proceed to continue_generation before the update finishes. + logger.info( + f"[DiagMTP][Callback] Calling _post('/callback/update_weights_tensor') " + f"with timeout={self.request_timeout}s..." + ) + try: + self._post("/callback/update_weights_tensor", payload) + _t2 = time.time() + logger.info( + f"[DiagMTP][Callback] _post completed in {_t2 - _t1:.3f}s " + f"(total: {_t2 - _t0:.3f}s)" + ) + except Exception as e: + _t2 = time.time() + logger.error( + f"[DiagMTP][Callback] _post FAILED after {_t2 - _t1:.3f}s " + f"(total: {_t2 - _t0:.3f}s): {type(e).__name__}: {e}" + ) + raise diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index a11a67e78e..e86a926801 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -3,6 +3,7 @@ import asyncio import shutil import threading +import time import traceback from collections import defaultdict from collections.abc import Callable @@ -559,9 +560,15 @@ def init_weights_group(): @app.route("/callback/update_weights_xccl", methods=["POST"]) def update_weights(): + _t0 = time.time() payload = request.get_json() or {} meta = deserialize_value(payload.get("meta")) param_specs = deserialize_value(payload.get("param_specs")) + _n_specs = len(param_specs) if param_specs else 0 + logger.info( + f"[DiagMTP] /callback/update_weights_xccl ENTERED " + f"(n_param_specs={_n_specs}, version={getattr(meta, 'version', '?')})" + ) # Fire-and-forget: schedule the NCCL weight update as a background # task and return HTTP 200 immediately. This prevents infrastructure # proxy timeouts (504) since the full NCCL transfer chain can take @@ -572,6 +579,10 @@ def update_weights(): self.update_weights_from_distributed(meta, param_specs), self._callback_loop, ) + logger.info( + f"[DiagMTP] /callback/update_weights_xccl returning HTTP 200 " + f"(fire-and-forget, handler took {time.time() - _t0:.3f}s)" + ) return jsonify({"status": "ok"}) @app.route("/callback/update_weights_disk", methods=["POST"]) @@ -585,28 +596,85 @@ def update_weights_disk(): @app.route("/callback/pause_generation", methods=["POST"]) def pause_generation(): + _t0 = time.time() + logger.info("[DiagMTP] /callback/pause_generation ENTERED") asyncio.run_coroutine_threadsafe( self.pause_generation(), self._callback_loop ).result() + logger.info( + f"[DiagMTP] /callback/pause_generation completed in {time.time() - _t0:.3f}s" + ) return jsonify({"status": "ok"}) @app.route("/callback/continue_generation", methods=["POST"]) def continue_generation(): + _t0 = time.time() + logger.info("[DiagMTP] /callback/continue_generation ENTERED") asyncio.run_coroutine_threadsafe( self.continue_generation(), self._callback_loop ).result() + logger.info( + f"[DiagMTP] /callback/continue_generation completed in {time.time() - _t0:.3f}s" + ) return jsonify({"status": "ok"}) @app.route("/callback/update_weights_tensor", methods=["POST"]) def update_weights_tensor(): + _t0 = time.time() + logger.info( + "[DiagMTP] /callback/update_weights_tensor handler ENTERED " + f"(flask_thread={threading.current_thread().name})" + ) payload = request.get_json() or {} - serialized_payload = deserialize_value( - payload.get("serialized_payload") + _t1 = time.time() + logger.info( + f"[DiagMTP] payload parsed in {_t1 - _t0:.3f}s, " + f"payload_keys={list(payload.keys())}, " + f"payload_size_bytes={len(str(payload))}" ) - asyncio.run_coroutine_threadsafe( - self.update_weights_from_tensor(serialized_payload), - self._callback_loop, - ).result() + serialized_payload = deserialize_value(payload.get("serialized_payload")) + _t2 = time.time() + logger.info( + f"[DiagMTP] deserialize_value completed in {_t2 - _t1:.3f}s, " + f"serialized_payload type={type(serialized_payload).__name__}, " + f"keys={list(serialized_payload.keys()) if isinstance(serialized_payload, dict) else 'N/A'}" + ) + # Check callback_loop health before scheduling + _loop = self._callback_loop + _loop_running = _loop is not None and _loop.is_running() + _loop_closed = _loop is not None and _loop.is_closed() + logger.info( + f"[DiagMTP] _callback_loop status: running={_loop_running}, " + f"closed={_loop_closed}, loop={_loop}" + ) + logger.info( + "[DiagMTP] Scheduling update_weights_from_tensor coroutine " + "on _callback_loop (BLOCKING with .result())..." + ) + _t3 = time.time() + try: + fut = asyncio.run_coroutine_threadsafe( + self.update_weights_from_tensor(serialized_payload), + self._callback_loop, + ) + logger.info( + f"[DiagMTP] Coroutine scheduled in {time.time() - _t3:.3f}s, " + f"fut={fut}, calling .result() to block..." + ) + fut.result() + _t4 = time.time() + logger.info( + f"[DiagMTP] .result() completed in {_t4 - _t3:.3f}s " + f"(total handler time: {_t4 - _t0:.3f}s). " + "Returning HTTP 200." + ) + except Exception as e: + logger.error( + f"[DiagMTP] .result() raised exception after " + f"{time.time() - _t3:.3f}s: {type(e).__name__}: {e}", + exc_info=True, + ) + raise return jsonify({"status": "ok"}) @app.route("/callback/rollout_complete", methods=["POST"]) @@ -731,6 +799,14 @@ async def _generic_collective_rpc_async( *args, **kwargs, ) -> list[Any]: + import time as _time + + _t0 = _time.time() + _worker_ids = [w.id for w in workers] + logger.info( + f"[DiagMTP] _generic_collective_rpc_async ENTERED: " + f"method={method}, n_workers={len(workers)}, workers={_worker_ids}" + ) tasks = [ self.scheduler.async_call_engine( worker_id=worker.id, @@ -741,7 +817,26 @@ async def _generic_collective_rpc_async( ) for rank, worker in enumerate(workers) ] - return await asyncio.gather(*tasks) + logger.info( + f"[DiagMTP] _generic_collective_rpc_async: " + f"{len(tasks)} tasks created for method={method}, " + f"calling asyncio.gather..." + ) + try: + results = await asyncio.gather(*tasks) + logger.info( + f"[DiagMTP] _generic_collective_rpc_async COMPLETED: " + f"method={method} in {_time.time() - _t0:.3f}s" + ) + return results + except Exception as e: + logger.error( + f"[DiagMTP] _generic_collective_rpc_async FAILED: " + f"method={method} after {_time.time() - _t0:.3f}s: " + f"{type(e).__name__}: {e}", + exc_info=True, + ) + raise def _choose_worker(self) -> tuple[Worker, int]: """Choose a worker for the next request using round-robin scheduling. @@ -1064,9 +1159,21 @@ async def init_weights_update_group(self, meta: WeightUpdateMeta) -> None: async def update_weights_from_distributed( self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] ): + import time as _time + + _t0 = _time.time() + _n_specs = len(param_specs) if param_specs else 0 + logger.info( + f"[DiagMTP] async update_weights_from_distributed ENTERED " + f"(n_specs={_n_specs}, version={getattr(meta, 'version', '?')})" + ) await self._collective_rpc_async( "update_weights_from_distributed", meta=meta, param_specs=param_specs ) + logger.info( + f"[DiagMTP] async update_weights_from_distributed COMPLETED " + f"in {_time.time() - _t0:.3f}s" + ) async def update_weights_from_disk(self, meta: WeightUpdateMeta): meta.clear_checkpoint_after_load = False @@ -1079,9 +1186,7 @@ async def pause_generation(self): async def continue_generation(self): await self._collective_rpc_async("continue_generation") - async def update_weights_from_tensor( - self, serialized_payload: dict - ) -> None: + async def update_weights_from_tensor(self, serialized_payload: dict) -> None: """Update EAGLE draft model MTP weights via tensor update path. Receives pre-serialized tensor data from the training side and @@ -1094,10 +1199,31 @@ async def update_weights_from_tensor( serialized_payload : dict Pre-serialized payload for /update_weights_from_tensor. """ - await self._collective_rpc_async( - "update_weights_from_tensor_serialized", - serialized_payload=serialized_payload, + import time as _time + + _t0 = _time.time() + _n_workers = len(self.workers) + _worker_ids = [w.id for w in self.workers] + logger.info( + f"[DiagMTP] async update_weights_from_tensor ENTERED on " + f"_callback_loop (n_workers={_n_workers}, workers={_worker_ids})" ) + try: + await self._collective_rpc_async( + "update_weights_from_tensor_serialized", + serialized_payload=serialized_payload, + ) + logger.info( + f"[DiagMTP] async update_weights_from_tensor COMPLETED " + f"in {_time.time() - _t0:.3f}s" + ) + except Exception as e: + logger.error( + f"[DiagMTP] async update_weights_from_tensor FAILED " + f"after {_time.time() - _t0:.3f}s: {type(e).__name__}: {e}", + exc_info=True, + ) + raise def set_version(self, version: int) -> None: with self._version_lock: diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index e332a5917c..09acf0dd96 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -215,7 +215,7 @@ def build_tensor_weight_update_request( named_tensors: list, tp_size: int = 1, flush_cache: bool = True, - ) -> "HttpRequest": + ) -> HttpRequest: """Build HTTP request for /update_weights_from_tensor. Used to update EAGLE draft model weights that are not @@ -1045,11 +1045,45 @@ def update_weights_from_tensor_serialized( Must contain 'serialized_named_tensors' (list of base64 strings) and optionally 'flush_cache' (bool). """ + import time as _time + + _t0 = _time.time() + _payload_keys = ( + list(serialized_payload.keys()) + if isinstance(serialized_payload, dict) + else "N/A" + ) + _n_tensors = ( + len(serialized_payload.get("serialized_named_tensors", [])) + if isinstance(serialized_payload, dict) + else 0 + ) + logger.info( + f"[DiagMTP][Worker] update_weights_from_tensor_serialized ENTERED: " + f"payload_keys={_payload_keys}, n_serialized_tensors={_n_tensors}, " + f"addresses={self.addresses}" + ) http_req = HttpRequest( endpoint="/update_weights_from_tensor", payload=serialized_payload, ) - self._run_request_on_all_servers(http_req) + logger.info( + f"[DiagMTP][Worker] Calling _run_request_on_all_servers for " + f"/update_weights_from_tensor to {len(self.addresses)} SGLang servers..." + ) + try: + self._run_request_on_all_servers(http_req) + logger.info( + f"[DiagMTP][Worker] update_weights_from_tensor_serialized " + f"COMPLETED in {_time.time() - _t0:.3f}s" + ) + except Exception as e: + logger.error( + f"[DiagMTP][Worker] update_weights_from_tensor_serialized " + f"FAILED after {_time.time() - _t0:.3f}s: {type(e).__name__}: {e}", + exc_info=True, + ) + raise def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: """Update weights in the inference engine from disk. @@ -1317,7 +1351,16 @@ def onload(self, tags: list[str] | None = None) -> None: self._run_request_on_all_servers(onload_req) def _run_request_on_all_servers(self, req: HttpRequest): + import time as _time + async def _fn(): + _t0 = _time.time() + logger.info( + f"[DiagMTP][Worker] _run_request_on_all_servers async _fn ENTERED: " + f"endpoint={req.endpoint}, n_addrs={len(self.addresses)}, " + f"addrs={self.addresses}, " + f"request_timeout={self.config.request_timeout}s" + ) async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=self.config.request_timeout), read_bufsize=1024 * 1024 * 10, @@ -1325,6 +1368,10 @@ async def _fn(): ) as session: jobs = [] for addr in self.addresses: + logger.info( + f"[DiagMTP][Worker] Creating request job: " + f"{req.method} {addr}{req.endpoint}" + ) jobs.append( arequest_with_retry( session=session, @@ -1336,7 +1383,24 @@ async def _fn(): timeout=self.config.request_timeout, ) ) - await asyncio.gather(*jobs) + logger.info( + f"[DiagMTP][Worker] Dispatching {len(jobs)} HTTP jobs " + f"via asyncio.gather for {req.endpoint}..." + ) + try: + await asyncio.gather(*jobs) + logger.info( + f"[DiagMTP][Worker] asyncio.gather completed for " + f"{req.endpoint} in {_time.time() - _t0:.3f}s" + ) + except Exception as e: + logger.error( + f"[DiagMTP][Worker] asyncio.gather FAILED for " + f"{req.endpoint} after {_time.time() - _t0:.3f}s: " + f"{type(e).__name__}: {e}", + exc_info=True, + ) + raise uvloop.run(_fn()) From a1c3e82b30fad2dd75c4fc773d987c9dd9e500c4 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 10:16:51 +0800 Subject: [PATCH 050/140] fix(engine): cuda ipc sync --- areal/engine/megatron_engine.py | 9 ++++ areal/infra/controller/rollout_callback.py | 8 +-- areal/infra/controller/rollout_controller.py | 55 +++++++++++++++++++- 3 files changed, 68 insertions(+), 4 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 51a8724e8e..9bc98fcf32 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2363,6 +2363,15 @@ def _serialize_mtp_tensors_for_update( monkey_patch_torch_reductions() + # Ensure all pending CUDA operations (e.g., NCCL all-gather from + # _collect_param) are complete before creating CUDA IPC handles. + # cudaIpcGetMemHandle requires the source tensor's GPU memory to be + # stable; without sync, pending async NCCL ops can cause the IPC + # handle creation to block indefinitely. + import torch + + torch.cuda.synchronize() + # Inner serialization: each tensor → CUDA IPC handle bytes serialized_pairs = [ (name, MultiprocessingSerializer.serialize(tensor.detach())) diff --git a/areal/infra/controller/rollout_callback.py b/areal/infra/controller/rollout_callback.py index 030dfffad3..1f9f3f959b 100644 --- a/areal/infra/controller/rollout_callback.py +++ b/areal/infra/controller/rollout_callback.py @@ -267,9 +267,11 @@ def update_weights_from_tensor( f"payload_approx_size={_payload_size} bytes" ) - # Synchronous blocking POST: wait for callback server to complete - # the tensor update before returning. This ensures training does - # not proceed to continue_generation before the update finishes. + # Synchronous blocking POST: MTP tensor update must complete before + # training proceeds to continue_generation. This follows verl/slime's + # approach of fully blocking weight updates. The callback server + # handler is also blocking (.result()), so HTTP 200 guarantees the + # tensor update is finished on the inference side. logger.info( f"[DiagMTP][Callback] Calling _post('/callback/update_weights_tensor') " f"with timeout={self.request_timeout}s..." diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index e86a926801..604d985f20 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import concurrent.futures import shutil import threading import time @@ -114,6 +115,12 @@ def __init__( self._callback_loop: asyncio.AbstractEventLoop | None = None self._callback_loop_ready = threading.Event() + # Pending fire-and-forget NCCL update futures. Tracked so the + # tensor update handler can drain them before dispatching its own + # RPC, preventing engine-thread queue starvation. + self._pending_xccl_futures: list[concurrent.futures.Future] = [] + self._xccl_futures_lock = threading.Lock() + # Task completion futures self._pending_futures: dict[int, asyncio.Future] = {} self._futures_lock = threading.Lock() @@ -575,10 +582,18 @@ def update_weights(): # >60s. NCCL broadcast is collective — when the training side's # broadcast handle completes, the receive side has the data. # Inspired by verl's pattern of decoupling HTTP from NCCL ops. - asyncio.run_coroutine_threadsafe( + fut = asyncio.run_coroutine_threadsafe( self.update_weights_from_distributed(meta, param_specs), self._callback_loop, ) + # Track future so tensor update can drain pending NCCL work + # before dispatching its own RPC to the engine thread queue. + with self._xccl_futures_lock: + # Prune completed futures to avoid unbounded growth + self._pending_xccl_futures = [ + f for f in self._pending_xccl_futures if not f.done() + ] + self._pending_xccl_futures.append(fut) logger.info( f"[DiagMTP] /callback/update_weights_xccl returning HTTP 200 " f"(fire-and-forget, handler took {time.time() - _t0:.3f}s)" @@ -639,6 +654,11 @@ def update_weights_tensor(): f"serialized_payload type={type(serialized_payload).__name__}, " f"keys={list(serialized_payload.keys()) if isinstance(serialized_payload, dict) else 'N/A'}" ) + # BLOCKING: MTP tensor update must complete before returning. + # Following verl/slime's fully-blocking weight update pattern. + # Unlike NCCL updates (fire-and-forget for concurrent collective + # participation), tensor updates are rank-0-only unilateral + # operations that can safely block. # Check callback_loop health before scheduling _loop = self._callback_loop _loop_running = _loop is not None and _loop.is_running() @@ -1194,11 +1214,44 @@ async def update_weights_from_tensor(self, serialized_payload: dict) -> None: payload directly to the SGLang server's /update_weights_from_tensor endpoint. + Before dispatching the tensor update RPC, drains all pending + fire-and-forget NCCL update futures. This ensures the worker's + engine thread queue is clear, preventing the tensor update from + being queued behind slow NCCL tasks (which would cause an + indefinite hang). Follows verl/slime's pattern of fully completing + all weight updates before proceeding. + Parameters ---------- serialized_payload : dict Pre-serialized payload for /update_weights_from_tensor. """ + # Drain all pending NCCL update futures before dispatching the + # tensor update. The NCCL updates and tensor update both go through + # the worker's single engine thread queue (via async_call_engine → + # /call → _submit_to_engine_thread). If NCCL tasks are still queued, + # the tensor update gets stuck behind them indefinitely. + with self._xccl_futures_lock: + pending = list(self._pending_xccl_futures) + self._pending_xccl_futures.clear() + + if pending: + logger.info( + f"[DiagMTP] Draining {len(pending)} pending NCCL futures " + f"before tensor update..." + ) + # Wait for all pending NCCL coroutines to complete. + # Use asyncio wrap to avoid blocking the event loop. + done_count = 0 + for fut in pending: + try: + await asyncio.wrap_future(fut) + done_count += 1 + except Exception as e: + logger.warning(f"[DiagMTP] Pending NCCL future raised: {e}") + done_count += 1 + logger.info(f"[DiagMTP] Drained {done_count}/{len(pending)} NCCL futures.") + import time as _time _t0 = _time.time() From b23abd11a06025866f54b33efc6c7c291a8918ac Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 11:10:51 +0800 Subject: [PATCH 051/140] fix(megatron): add log --- areal/engine/megatron_engine.py | 121 +++++++++++++++++++++++++------- 1 file changed, 94 insertions(+), 27 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 9bc98fcf32..c09276df4a 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2345,55 +2345,122 @@ def _serialize_mtp_tensors_for_update( Dict with 'serialized_named_tensors' and 'flush_cache' keys, ready for /update_weights_from_tensor endpoint. """ + import time as _time + + _t_total = _time.time() + _total_bytes = sum(t.numel() * t.element_size() for _, t in mtp_hf_tensors) + _tensor_names = [name for name, _ in mtp_hf_tensors] + _tensor_shapes = [tuple(t.shape) for _, t in mtp_hf_tensors] + _tensor_dtypes = [str(t.dtype) for _, t in mtp_hf_tensors] + _tensor_sizes = [t.numel() * t.element_size() for _, t in mtp_hf_tensors] + self.logger.info( + f"[MTPSerialize] ENTERED: n_tensors={len(mtp_hf_tensors)}, " + f"tp_size={tp_size}, total_raw_bytes={_total_bytes} " + f"({_total_bytes / 1024 / 1024:.2f} MB), " + f"tensor_names={_tensor_names}, " + f"tensor_shapes={_tensor_shapes}, " + f"tensor_dtypes={_tensor_dtypes}, " + f"tensor_sizes_bytes={_tensor_sizes}" + ) + try: from sglang.srt.utils import MultiprocessingSerializer - from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions except ImportError: + self.logger.error( + "[MTPSerialize] Failed to import MultiprocessingSerializer from sglang" + ) raise ImportError( "SGLang >= 0.5.9 is required for tensor weight updates. " "Install sglang to use MTP draft weight sync." ) + self.logger.info( + "[MTPSerialize] MultiprocessingSerializer imported successfully" + ) + try: from sglang.srt.model_executor.model_runner import LocalSerializedTensor except ImportError: + self.logger.error( + "[MTPSerialize] Failed to import LocalSerializedTensor from sglang" + ) raise ImportError( "Cannot import LocalSerializedTensor from SGLang. " "Ensure sglang >= 0.5.9 is installed." ) + self.logger.info("[MTPSerialize] LocalSerializedTensor imported successfully") + + _t_ser0 = _time.time() + serialized_pairs = [] + for name, tensor in mtp_hf_tensors: + _t_ser_i = _time.time() + _cpu_tensor = tensor.detach().cpu() + _ser_data = MultiprocessingSerializer.serialize(_cpu_tensor) + _ser_len = len(_ser_data) + serialized_pairs.append((name, _ser_data)) + self.logger.info( + f"[MTPSerialize] Serialized tensor '{name}': " + f"shape={tuple(tensor.shape)}, dtype={tensor.dtype}, " + f"device={tensor.device}, " + f"raw_bytes={tensor.numel() * tensor.element_size()}, " + f"serialized_bytes={_ser_len} ({_ser_len / 1024 / 1024:.2f} MB), " + f"took {_time.time() - _t_ser_i:.3f}s" + ) + self.logger.info( + f"[MTPSerialize] All inner serializations completed in " + f"{_time.time() - _t_ser0:.3f}s, " + f"n_pairs={len(serialized_pairs)}, " + f"total_serialized_bytes={sum(len(d) for _, d in serialized_pairs)} " + f"({sum(len(d) for _, d in serialized_pairs) / 1024 / 1024:.2f} MB)" + ) - monkey_patch_torch_reductions() - - # Ensure all pending CUDA operations (e.g., NCCL all-gather from - # _collect_param) are complete before creating CUDA IPC handles. - # cudaIpcGetMemHandle requires the source tensor's GPU memory to be - # stable; without sync, pending async NCCL ops can cause the IPC - # handle creation to block indefinitely. - import torch - - torch.cuda.synchronize() - - # Inner serialization: each tensor → CUDA IPC handle bytes - serialized_pairs = [ - (name, MultiprocessingSerializer.serialize(tensor.detach())) - for name, tensor in mtp_hf_tensors - ] - - # Wrap in LocalSerializedTensor: one entry per TP rank. - # All TP ranks get the same data; SGLang load_weights() handles slicing. + _t_wrap0 = _time.time() per_rank_named_tensors = [ (name, LocalSerializedTensor(values=[data] * tp_size)) for name, data in serialized_pairs ] + self.logger.info( + f"[MTPSerialize] LocalSerializedTensor wrapping completed in " + f"{_time.time() - _t_wrap0:.3f}s, " + f"n_entries={len(per_rank_named_tensors)}, tp_size={tp_size}" + ) - # Outer serialization + base64 for JSON transport import base64 - serialized_named_tensors = [ - base64.b64encode( - MultiprocessingSerializer.serialize(per_rank_named_tensors) - ).decode("utf-8") - for _ in range(tp_size) - ] + _t_outer0 = _time.time() + _outer_payload = MultiprocessingSerializer.serialize(per_rank_named_tensors) + _outer_len = len(_outer_payload) + self.logger.info( + f"[MTPSerialize] Outer MultiprocessingSerializer.serialize completed: " + f"payload_bytes={_outer_len} ({_outer_len / 1024 / 1024:.2f} MB), " + f"took {_time.time() - _t_outer0:.3f}s" + ) + + _t_b64_0 = _time.time() + _b64_str = base64.b64encode(_outer_payload).decode("utf-8") + _b64_len = len(_b64_str) + self.logger.info( + f"[MTPSerialize] base64 encode completed: " + f"b64_str_len={_b64_len} ({_b64_len / 1024 / 1024:.2f} MB), " + f"overhead_ratio={_b64_len / _outer_len:.2f}x, " + f"took {_time.time() - _t_b64_0:.3f}s" + ) + + serialized_named_tensors = [_b64_str for _ in range(tp_size)] + self.logger.info( + f"[MTPSerialize] Replicated b64 payload for {tp_size} TP ranks, " + f"total_b64_bytes={_b64_len * tp_size} " + f"({_b64_len * tp_size / 1024 / 1024:.2f} MB)" + ) + + _t_total_elapsed = _time.time() - _t_total + self.logger.info( + f"[MTPSerialize] COMPLETED: total_time={_t_total_elapsed:.3f}s, " + f"n_tensors={len(mtp_hf_tensors)}, tp_size={tp_size}, " + f"raw_bytes={_total_bytes} ({_total_bytes / 1024 / 1024:.2f} MB), " + f"final_b64_per_rank={_b64_len} ({_b64_len / 1024 / 1024:.2f} MB), " + f"final_total_b64={_b64_len * tp_size} " + f"({_b64_len * tp_size / 1024 / 1024:.2f} MB)" + ) return { "serialized_named_tensors": serialized_named_tensors, From e7c3f7ba2b0c8756d13ef00613371663b95599ad Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 14:24:16 +0800 Subject: [PATCH 052/140] fix(engine): improve serialize --- areal/engine/megatron_engine.py | 39 +++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index c09276df4a..2c5853f364 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2363,6 +2363,34 @@ def _serialize_mtp_tensors_for_update( f"tensor_sizes_bytes={_tensor_sizes}" ) + # ------------------------------------------------------------------- + # Explicit CUDA sync + cache cleanup BEFORE GPU->CPU + # copies. Under near-OOM conditions (97%+ VRAM) the implicit CUDA + # synchronisation triggered by .cpu() can deadlock because the CUDA + # runtime attempts to reclaim caches while prior NCCL / compute + # kernels are still draining. Synchronising and freeing caches + # first makes the subsequent .cpu() calls plain DtoH memcpys. + # ------------------------------------------------------------------- + _t_sync = _time.time() + self.logger.info("[MTPSerialize] torch.cuda.synchronize() ...") + torch.cuda.synchronize() + self.logger.info( + f"[MTPSerialize] torch.cuda.synchronize() done in " + f"{_time.time() - _t_sync:.3f}s" + ) + + _t_cache = _time.time() + _before = torch.cuda.memory_reserved() + torch.cuda.empty_cache() + _after = torch.cuda.memory_reserved() + self.logger.info( + f"[MTPSerialize] empty_cache freed " + f"{(_before - _after) / 1024 / 1024:.1f} MB GPU cache " + f"(reserved {_before / 1024 / 1024:.0f} -> " + f"{_after / 1024 / 1024:.0f} MB), " + f"took {_time.time() - _t_cache:.3f}s" + ) + try: from sglang.srt.utils import MultiprocessingSerializer except ImportError: @@ -2390,11 +2418,18 @@ def _serialize_mtp_tensors_for_update( self.logger.info("[MTPSerialize] LocalSerializedTensor imported successfully") _t_ser0 = _time.time() + import io as _io + import pickle as _pickle + serialized_pairs = [] for name, tensor in mtp_hf_tensors: _t_ser_i = _time.time() - _cpu_tensor = tensor.detach().cpu() - _ser_data = MultiprocessingSerializer.serialize(_cpu_tensor) + _cpu_tensor = tensor.detach().cpu().contiguous() + # Standard pickle -- no shared-memory, no CUDA IPC handles. + _buf = _io.BytesIO() + _pickle.dump(_cpu_tensor, _buf, protocol=_pickle.HIGHEST_PROTOCOL) + _ser_data = _buf.getvalue() + del _buf # release buffer immediately _ser_len = len(_ser_data) serialized_pairs.append((name, _ser_data)) self.logger.info( From dd3eeeafbad5180cd303f46f139c7a2bbfc846d1 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 16:08:36 +0800 Subject: [PATCH 053/140] fix(engine): skip NCCL broadcast --- areal/engine/megatron_engine.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 2c5853f364..f133e5474e 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2552,6 +2552,23 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: fp8_direct_convert=self.fp8_direct_convert, ) ) + # --------------------------------------------------------- + # MTP weights are synced to the EAGLE draft model via the + # separate /update_weights_from_tensor HTTP path (see below). + # SGLang's MiMoForCausalLM.load_weights() silently drops + # any name containing "mtp_layers", so broadcasting them + # via NCCL is redundant. Worse, the inference-side NCCL + # handler calls flush_cache() -> torch.cuda.empty_cache() + # which blocks on the inference GPU while training-side + # torch.cuda.synchronize() in _serialize_mtp_tensors waits + # for the same NCCL group -- creating a circular deadlock. + # + # By skipping NCCL for MTP params we: + # 1. Eliminate the deadlock between flush_cache and sync + # 2. Reduce unnecessary NCCL traffic (~402 MB saved) + # 3. Avoid engine-thread queue contention on inference side + # --------------------------------------------------------- + continue if self.config.use_lora and ( ".adapter." not in name or not getattr(param, "requires_grad", False) ): From 1e6a4536ccb9c158fa2ae978538d24c7eae0ba20 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 16:48:11 +0800 Subject: [PATCH 054/140] fix(engine): improve --- areal/engine/megatron_engine.py | 45 +++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index f133e5474e..a1f47aed3c 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2364,27 +2364,52 @@ def _serialize_mtp_tensors_for_update( ) # ------------------------------------------------------------------- - # Explicit CUDA sync + cache cleanup BEFORE GPU->CPU - # copies. Under near-OOM conditions (97%+ VRAM) the implicit CUDA - # synchronisation triggered by .cpu() can deadlock because the CUDA - # runtime attempts to reclaim caches while prior NCCL / compute - # kernels are still draining. Synchronising and freeing caches - # first makes the subsequent .cpu() calls plain DtoH memcpys. + # Explicit CUDA *default-stream* sync + cache cleanup BEFORE + # GPU->CPU copies. + # + # We intentionally use current_stream().synchronize() instead of + # torch.cuda.synchronize() because the latter waits for ALL CUDA + # streams, including the NCCL stream used by + # _update_bucket_weights_from_distributed(). Those NCCL broadcasts + # run with async_op=True; handle.wait() guarantees CPU-side + # completion, but the GPU-side NCCL kernels may still be pending + # on the NCCL stream. Meanwhile the inference side is still + # processing earlier XCCL buckets sequentially (~2.4s per bucket + # x 15 buckets ~ 36s total). torch.cuda.synchronize() would + # block indefinitely waiting for those NCCL kernels to retire, + # but they cannot retire until the inference-side counterparts + # execute — a circular hang. + # + # The MTP tensors we serialize here come from _collect_param()'s + # dist.all_gather() which executes on the default stream, so + # syncing only the default stream is correct and sufficient. # ------------------------------------------------------------------- _t_sync = _time.time() - self.logger.info("[MTPSerialize] torch.cuda.synchronize() ...") - torch.cuda.synchronize() self.logger.info( - f"[MTPSerialize] torch.cuda.synchronize() done in " + "[MTPSerialize] current_stream().synchronize() ... " + f"(device={torch.cuda.current_device()}, " + f"stream={torch.cuda.current_stream()})" + ) + torch.cuda.current_stream().synchronize() + self.logger.info( + f"[MTPSerialize] current_stream().synchronize() done in " f"{_time.time() - _t_sync:.3f}s" ) + # Reclaim unused cached memory. gc.collect() first to release + # any Python-side tensor references, then empty_cache() to return + # the CUDA blocks to the allocator. Note: empty_cache() may + # internally trigger a device-wide sync if the CUDA runtime deems + # it necessary, but in practice after current_stream sync + GC + # this is fast and non-blocking for NCCL streams. + import gc _t_cache = _time.time() + gc.collect() _before = torch.cuda.memory_reserved() torch.cuda.empty_cache() _after = torch.cuda.memory_reserved() self.logger.info( - f"[MTPSerialize] empty_cache freed " + f"[MTPSerialize] gc.collect() + empty_cache freed " f"{(_before - _after) / 1024 / 1024:.1f} MB GPU cache " f"(reserved {_before / 1024 / 1024:.0f} -> " f"{_after / 1024 / 1024:.0f} MB), " From e7a6b3899586aac1a7a9019d426f566dba9bc45b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 18:10:19 +0800 Subject: [PATCH 055/140] fix(engine): fix nccl block --- areal/engine/megatron_engine.py | 113 +++++++++++++++++++++----------- 1 file changed, 76 insertions(+), 37 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index a1f47aed3c..447917a4ad 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2364,55 +2364,64 @@ def _serialize_mtp_tensors_for_update( ) # ------------------------------------------------------------------- - # Explicit CUDA *default-stream* sync + cache cleanup BEFORE - # GPU->CPU copies. + # GPU → CPU copy on a *dedicated CUDA stream* that is insulated + # from NCCL broadcast dependencies. # - # We intentionally use current_stream().synchronize() instead of - # torch.cuda.synchronize() because the latter waits for ALL CUDA - # streams, including the NCCL stream used by - # _update_bucket_weights_from_distributed(). Those NCCL broadcasts - # run with async_op=True; handle.wait() guarantees CPU-side - # completion, but the GPU-side NCCL kernels may still be pending - # on the NCCL stream. Meanwhile the inference side is still - # processing earlier XCCL buckets sequentially (~2.4s per bucket - # x 15 buckets ~ 36s total). torch.cuda.synchronize() would - # block indefinitely waiting for those NCCL kernels to retire, - # but they cannot retire until the inference-side counterparts - # execute — a circular hang. - # - # The MTP tensors we serialize here come from _collect_param()'s - # dist.all_gather() which executes on the default stream, so - # syncing only the default stream is correct and sufficient. + # Recorded _mtp_data_ready_event on the default stream + # BEFORE any NCCL broadcasts started (in _update_weights_from_ + # distributed). Here we create a fresh stream that waits ONLY on + # that event, then do all .cpu() copies on the fresh stream. + # This stream has no NCCL dependencies, so its synchronize() is + # instantaneous once the MTP all_gather data is ready. # ------------------------------------------------------------------- _t_sync = _time.time() + + # Create a dedicated serialization stream free of NCCL deps + _ser_stream = torch.cuda.Stream() + + _has_event = hasattr(self, "_mtp_data_ready_event") and self._mtp_data_ready_event is not None + if _has_event: + # Make ser_stream wait for MTP data (all_gather) but NOT NCCL broadcasts + _ser_stream.wait_event(self._mtp_data_ready_event) + self.logger.info( + "[MTPSerialize] Created serialization stream and synced with " + "_mtp_data_ready_event (pre-NCCL). " + f"(device={torch.cuda.current_device()}, " + f"default_stream={torch.cuda.current_stream()}, " + f"ser_stream={_ser_stream})" + ) + else: + # Fallback: no event recorded (shouldn't happen, but be safe). + # Wait on the default stream which may include NCCL deps. + _ser_stream.wait_stream(torch.cuda.current_stream()) + self.logger.warning( + "[MTPSerialize] _mtp_data_ready_event NOT found! " + "Falling back to wait_stream(current_stream) — " + "this may block on NCCL. " + f"(device={torch.cuda.current_device()})" + ) + + _ser_stream.synchronize() self.logger.info( - "[MTPSerialize] current_stream().synchronize() ... " - f"(device={torch.cuda.current_device()}, " - f"stream={torch.cuda.current_stream()})" - ) - torch.cuda.current_stream().synchronize() - self.logger.info( - f"[MTPSerialize] current_stream().synchronize() done in " + f"[MTPSerialize] Serialization stream synced in " f"{_time.time() - _t_sync:.3f}s" ) - # Reclaim unused cached memory. gc.collect() first to release - # any Python-side tensor references, then empty_cache() to return - # the CUDA blocks to the allocator. Note: empty_cache() may - # internally trigger a device-wide sync if the CUDA runtime deems - # it necessary, but in practice after current_stream sync + GC - # this is fast and non-blocking for NCCL streams. + # Reclaim Python-side references before GPU→CPU copies. + # We skip torch.cuda.empty_cache() here because it can trigger + # an implicit device-wide sync (cudaDeviceSynchronize) which + # would re-introduce the NCCL deadlock under near-OOM conditions. + # Instead, gc.collect() alone frees Python-side tensor refs, + # and the CUDA allocator will reuse freed blocks lazily. import gc _t_cache = _time.time() gc.collect() _before = torch.cuda.memory_reserved() - torch.cuda.empty_cache() _after = torch.cuda.memory_reserved() self.logger.info( - f"[MTPSerialize] gc.collect() + empty_cache freed " - f"{(_before - _after) / 1024 / 1024:.1f} MB GPU cache " - f"(reserved {_before / 1024 / 1024:.0f} -> " - f"{_after / 1024 / 1024:.0f} MB), " + f"[MTPSerialize] gc.collect() completed " + f"(reserved={_before / 1024 / 1024:.0f} MB, " + f"no empty_cache to avoid device-wide sync), " f"took {_time.time() - _t_cache:.3f}s" ) @@ -2449,7 +2458,10 @@ def _serialize_mtp_tensors_for_update( serialized_pairs = [] for name, tensor in mtp_hf_tensors: _t_ser_i = _time.time() - _cpu_tensor = tensor.detach().cpu().contiguous() + # Perform GPU→CPU copy on the serialization stream which + # is free of NCCL cross-stream dependencies. + with torch.cuda.stream(_ser_stream): + _cpu_tensor = tensor.detach().cpu().contiguous() # Standard pickle -- no shared-memory, no CUDA IPC handles. _buf = _io.BytesIO() _pickle.dump(_cpu_tensor, _buf, protocol=_pickle.HIGHEST_PROTOCOL) @@ -2607,6 +2619,33 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: weight_chunked_mem_size, ) + # Record a CUDA event on the default stream BEFORE any NCCL + # broadcasts begin. At this point, all MTP tensors from + # _collect_param()'s synchronous dist.all_gather() are fully + # materialised on the default stream. We will use this event + # in _serialize_mtp_tensors_for_update() to create a separate + # CUDA stream that depends ONLY on work up to this point — + # crucially, NOT on the NCCL broadcast operations that follow. + # + # Why this matters: PyTorch's NCCL handle.wait() inserts a + # cross-stream dependency (NCCL stream → default stream) via + # cudaStreamWaitEvent. After handle.wait(), the default stream + # implicitly waits for the NCCL kernels. If we later call + # current_stream().synchronize() or .cpu() on the default stream, + # we block until ALL pending NCCL broadcasts complete on the GPU — + # but those broadcasts are collective ops that need the inference + # side to execute its counterpart. The inference side processes + # XCCL callbacks sequentially (each ~2.4s), so the training side + # hangs for tens of seconds or indefinitely. + if _collect_mtp_for_draft and mtp_hf_tensors: + self._mtp_data_ready_event = torch.cuda.Event() + self._mtp_data_ready_event.record(torch.cuda.current_stream()) + self.logger.info( + f"[MTPTrain] Recorded _mtp_data_ready_event on default stream " + f"(device={torch.cuda.current_device()}) BEFORE NCCL broadcasts. " + f"n_mtp_tensors={len(mtp_hf_tensors)}" + ) + # Only pipeline parallel heads CAN contain named tensors here if converted_named_tensors: self._update_bucket_weights_from_distributed(meta, converted_named_tensors) From 702802032c792d0a4b7699104809338251486ca5 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 19:19:41 +0800 Subject: [PATCH 056/140] refactor(rollout_controller): add log metric --- areal/engine/megatron_engine.py | 222 ++++++++++++++----- areal/infra/controller/rollout_callback.py | 15 +- areal/infra/controller/rollout_controller.py | 77 +++++-- areal/infra/remote_inf_engine.py | 38 +++- 4 files changed, 268 insertions(+), 84 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 447917a4ad..873fb130fb 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2048,7 +2048,9 @@ def _update_bucket_weights_from_distributed( meta: WeightUpdateMeta, converted_named_tensors: list[tuple[str, nn.Parameter | torch.Tensor]], ) -> None: - # Early exit when chunk size is relatively small + import time as _diag_time + + _diag_t0 = _diag_time.time() if not converted_named_tensors: return @@ -2073,24 +2075,43 @@ def _update_bucket_weights_from_distributed( "bias": "none", } + self.logger.info( + f"[DiagBucket] _update_bucket_weights_from_distributed ENTERED: " + f"n_tensors={len(converted_named_tensors)}, n_specs={len(param_specs)}, " + f"names={[n for n, _ in converted_named_tensors[:5]]}..." + ) + _t_post0 = _diag_time.time() fut = self.rollout_engine.update_weights_from_distributed(meta, param_specs) + self.logger.info( + f"[DiagBucket] rollout_engine.update_weights_from_distributed POST sent " + f"in {_diag_time.time() - _t_post0:.3f}s, fut={fut}" + ) + _t_bc0 = _diag_time.time() handles = [] - for _, param in converted_named_tensors: + for idx, (name, param) in enumerate(converted_named_tensors): handles.append( dist.broadcast( param.data, 0, group=self.weight_update_group, async_op=True ) ) - for handle in handles: + self.logger.info( + f"[DiagBucket] Enqueued {len(handles)} async broadcasts " + f"in {_diag_time.time() - _t_bc0:.3f}s, calling handle.wait()..." + ) + _t_wait0 = _diag_time.time() + for idx, handle in enumerate(handles): handle.wait() + if idx % 10 == 0 or idx == len(handles) - 1: + self.logger.info( + f"[DiagBucket] handle.wait() progress: {idx + 1}/{len(handles)} " + f"after {_diag_time.time() - _t_wait0:.3f}s" + ) + self.logger.info( + f"[DiagBucket] All handle.wait() completed in " + f"{_diag_time.time() - _t_wait0:.3f}s" + ) - # The callback server now returns HTTP 200 immediately (fire-and-forget) - # before the NCCL transfer completes on the inference side. Since NCCL - # broadcast is collective, handle.wait() above already guarantees BOTH - # sides have completed the data transfer. fut.result() only confirms - # the HTTP POST was accepted. Use a short timeout to catch delivery - # errors without blocking on infrastructure proxy timeouts (504). try: fut.result(timeout=30) except TimeoutError: @@ -2107,6 +2128,10 @@ def _update_bucket_weights_from_distributed( converted_named_tensors.clear() self.engine_lock.release() + self.logger.info( + f"[DiagBucket] _update_bucket_weights_from_distributed COMPLETED " + f"in {_diag_time.time() - _diag_t0:.3f}s" + ) @property def _duplicated_param_names(self) -> set[str]: @@ -2381,7 +2406,12 @@ def _serialize_mtp_tensors_for_update( _has_event = hasattr(self, "_mtp_data_ready_event") and self._mtp_data_ready_event is not None if _has_event: - # Make ser_stream wait for MTP data (all_gather) but NOT NCCL broadcasts + _event_query_before = self._mtp_data_ready_event.query() + self.logger.info( + f"[MTPSerialize] _mtp_data_ready_event status BEFORE wait_event: " + f"query()={_event_query_before} (True=signaled, False=pending), " + f"event={self._mtp_data_ready_event}" + ) _ser_stream.wait_event(self._mtp_data_ready_event) self.logger.info( "[MTPSerialize] Created serialization stream and synced with " @@ -2391,8 +2421,6 @@ def _serialize_mtp_tensors_for_update( f"ser_stream={_ser_stream})" ) else: - # Fallback: no event recorded (shouldn't happen, but be safe). - # Wait on the default stream which may include NCCL deps. _ser_stream.wait_stream(torch.cuda.current_stream()) self.logger.warning( "[MTPSerialize] _mtp_data_ready_event NOT found! " @@ -2401,7 +2429,68 @@ def _serialize_mtp_tensors_for_update( f"(device={torch.cuda.current_device()})" ) - _ser_stream.synchronize() + self.logger.info( + f"[MTPSerialize] About to call _ser_stream.synchronize()... " + f"(ser_stream={_ser_stream}, " + f"cuda_device={torch.cuda.current_device()}, " + f"mem_alloc={torch.cuda.memory_allocated() / 1024 / 1024:.0f} MB, " + f"mem_reserved={torch.cuda.memory_reserved() / 1024 / 1024:.0f} MB)" + ) + + import threading + + _sync_done = threading.Event() + _sync_exc = [None] + + def _do_sync(): + try: + _ser_stream.synchronize() + except Exception as exc: + _sync_exc[0] = exc + finally: + _sync_done.set() + + _sync_thread = threading.Thread(target=_do_sync, daemon=True) + _sync_thread.start() + + _sync_wait_start = _time.time() + _sync_timeout = 30.0 + while not _sync_done.wait(timeout=1.0): + _waited = _time.time() - _sync_wait_start + if _has_event: + _eq = self._mtp_data_ready_event.query() + self.logger.warning( + f"[MTPSerialize] _ser_stream.synchronize() STILL WAITING " + f"after {_waited:.1f}s (timeout={_sync_timeout}s)! " + f"event_query={_eq}, " + f"mem_alloc={torch.cuda.memory_allocated() / 1024 / 1024:.0f} MB, " + f"mem_reserved={torch.cuda.memory_reserved() / 1024 / 1024:.0f} MB" + ) + else: + self.logger.warning( + f"[MTPSerialize] _ser_stream.synchronize() STILL WAITING " + f"after {_waited:.1f}s (timeout={_sync_timeout}s)! " + f"mem_alloc={torch.cuda.memory_allocated() / 1024 / 1024:.0f} MB, " + f"mem_reserved={torch.cuda.memory_reserved() / 1024 / 1024:.0f} MB" + ) + if _waited >= _sync_timeout: + self.logger.error( + f"[MTPSerialize] _ser_stream.synchronize() TIMEOUT after " + f"{_waited:.1f}s! This indicates a CUDA stream deadlock. " + f"The _mtp_data_ready_event is likely blocked by NCCL " + f"cross-stream dependencies on the default stream." + ) + break + + if _sync_exc[0] is not None: + raise _sync_exc[0] + + if _has_event: + _event_query_after = self._mtp_data_ready_event.query() + self.logger.info( + f"[MTPSerialize] _mtp_data_ready_event status AFTER synchronize: " + f"query()={_event_query_after}" + ) self.logger.info( f"[MTPSerialize] Serialization stream synced in " f"{_time.time() - _t_sync:.3f}s" @@ -2541,8 +2630,16 @@ def _serialize_mtp_tensors_for_update( @trace_perf("megatron_engine.update_weights_from_distributed", category="comm") def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: + import time as _diag_time + + _diag_t0 = _diag_time.time() DeviceRuntimeInfo.get_current().log("_update_weights_from_distributed start") - # Reset weight weight meta with local info + self.logger.info( + f"[DiagUW] _update_weights_from_distributed ENTERED " + f"(rank={dist.get_rank()}, version={meta.version}, " + f"mem_alloc={torch.cuda.memory_allocated() / 1024 / 1024:.0f} MB, " + f"mem_reserved={torch.cuda.memory_reserved() / 1024 / 1024:.0f} MB)" + ) meta.nccl_master_address = self.weight_update_master_addr meta.nccl_master_port = self.weight_update_master_port meta.nccl_group_name = self.weight_update_group_name @@ -2619,36 +2716,38 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: weight_chunked_mem_size, ) - # Record a CUDA event on the default stream BEFORE any NCCL - # broadcasts begin. At this point, all MTP tensors from - # _collect_param()'s synchronous dist.all_gather() are fully - # materialised on the default stream. We will use this event - # in _serialize_mtp_tensors_for_update() to create a separate - # CUDA stream that depends ONLY on work up to this point — - # crucially, NOT on the NCCL broadcast operations that follow. - # - # Why this matters: PyTorch's NCCL handle.wait() inserts a - # cross-stream dependency (NCCL stream → default stream) via - # cudaStreamWaitEvent. After handle.wait(), the default stream - # implicitly waits for the NCCL kernels. If we later call - # current_stream().synchronize() or .cpu() on the default stream, - # we block until ALL pending NCCL broadcasts complete on the GPU — - # but those broadcasts are collective ops that need the inference - # side to execute its counterpart. The inference side processes - # XCCL callbacks sequentially (each ~2.4s), so the training side - # hangs for tens of seconds or indefinitely. + self.logger.info( + f"[DiagUW] Parameter loop completed in " + f"{_diag_time.time() - _diag_t0:.3f}s. " + f"mtp_hf_tensors={len(mtp_hf_tensors)}, " + f"converted_named_tensors={len(converted_named_tensors)}, " + f"mtp_param_count={mtp_param_count}, " + f"buffer_size={buffer_size}" + ) + if _collect_mtp_for_draft and mtp_hf_tensors: self._mtp_data_ready_event = torch.cuda.Event() self._mtp_data_ready_event.record(torch.cuda.current_stream()) + _event_recorded_at = _diag_time.time() self.logger.info( - f"[MTPTrain] Recorded _mtp_data_ready_event on default stream " + f"[DiagUW] Recorded _mtp_data_ready_event on default stream " f"(device={torch.cuda.current_device()}) BEFORE NCCL broadcasts. " - f"n_mtp_tensors={len(mtp_hf_tensors)}" + f"n_mtp_tensors={len(mtp_hf_tensors)}, " + f"elapsed={_event_recorded_at - _diag_t0:.3f}s, " + f"default_stream={torch.cuda.current_stream()}" ) - # Only pipeline parallel heads CAN contain named tensors here if converted_named_tensors: + self.logger.info( + f"[DiagUW] Calling _update_bucket_weights_from_distributed with " + f"{len(converted_named_tensors)} tensors at elapsed=" + f"{_diag_time.time() - _diag_t0:.3f}s" + ) self._update_bucket_weights_from_distributed(meta, converted_named_tensors) + self.logger.info( + f"[DiagUW] _update_bucket_weights_from_distributed completed at elapsed=" + f"{_diag_time.time() - _diag_t0:.3f}s" + ) elif self.is_pipeline_parallel_head() and not self.config.use_lora: self.logger.warning( "No tensors were collected for distributed update at version %s.", @@ -2668,17 +2767,6 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"MTP draft model weights will NOT be updated!" ) - # --- MTP Draft Model Weight Update via /update_weights_from_tensor --- - # SGLang v0.5.9 /update_weights_from_distributed routes ONLY to - # tp_worker, which does NOT update the draft_worker (EAGLEWorker). - # MTP weights sent via NCCL are silently dropped by - # MiMoForCausalLM.load_weights() ("if mtp_layers in name: continue"). - # - # use /update_weights_from_tensor - # which routes to draft_worker. EAGLEWorker.update_weights_from_tensor() - # updates BOTH self.model_runner (draft/MiMoMTP) and - # self.target_worker.model_runner (target/MiMoForCausalLM). - # Each model's load_weights() silently skips non-matching names. if _collect_mtp_for_draft and mtp_hf_tensors and dist.get_rank() == 0: try: tp_size = ( @@ -2692,25 +2780,23 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: import time as _time self.logger.info( - f"[MTPTrain] Sending {len(mtp_hf_tensors)} MTP tensors " + f"[DiagUW] About to serialize and send {len(mtp_hf_tensors)} MTP tensors " f"({_mtp_bytes / 1024 / 1024:.2f} MB) to EAGLE draft model " f"via /update_weights_from_tensor " - f"(tp_size={tp_size}, version={meta.version})" + f"(tp_size={tp_size}, version={meta.version}), " + f"elapsed={_diag_time.time() - _diag_t0:.3f}s, " + f"mem_alloc={torch.cuda.memory_allocated() / 1024 / 1024:.0f} MB, " + f"mem_reserved={torch.cuda.memory_reserved() / 1024 / 1024:.0f} MB" ) - # Serialize tensors on the training side. This is required - # for single-controller mode where the engine is a - # RolloutCallback proxy — tensor data must be serialized - # before it can travel through the HTTP callback chain. _t_ser0 = _time.time() self.logger.info( - f"[MTPTrain][Diag] Starting _serialize_mtp_tensors_for_update " + f"[DiagUW] Starting _serialize_mtp_tensors_for_update " f"(n_tensors={len(mtp_hf_tensors)}, tp_size={tp_size})..." ) serialized_payload = self._serialize_mtp_tensors_for_update( mtp_hf_tensors, tp_size ) _t_ser1 = _time.time() - # Log serialized payload info _sp_keys = ( list(serialized_payload.keys()) if isinstance(serialized_payload, dict) @@ -2730,14 +2816,15 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: else [] ) self.logger.info( - f"[MTPTrain][Diag] Serialization completed in {_t_ser1 - _t_ser0:.3f}s. " + f"[DiagUW] Serialization completed in {_t_ser1 - _t_ser0:.3f}s. " f"payload_keys={_sp_keys}, n_serialized_tensors={_n_snt}, " f"serialized_tensor_sizes_bytes={_snt_sizes}, " f"rollout_engine_type={type(self.rollout_engine).__name__}" ) _t_call0 = _time.time() self.logger.info( - "[MTPTrain][Diag] Calling rollout_engine.update_weights_from_tensor()..." + f"[DiagUW] Calling rollout_engine.update_weights_from_tensor()... " + f"(engine_type={type(self.rollout_engine).__name__})" ) self.rollout_engine.update_weights_from_tensor( serialized_payload=serialized_payload, @@ -2745,11 +2832,12 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) _t_call1 = _time.time() self.logger.info( - f"[MTPTrain] Successfully updated EAGLE draft model " + f"[DiagUW] Successfully updated EAGLE draft model " f"MTP weights at version={meta.version} " f"(serialize={_t_ser1 - _t_ser0:.3f}s, " f"update_call={_t_call1 - _t_call0:.3f}s, " - f"total={_t_call1 - _t_ser0:.3f}s)" + f"total={_t_call1 - _t_ser0:.3f}s, " + f"overall_elapsed={_diag_time.time() - _diag_t0:.3f}s)" ) except Exception as e: self.logger.error( @@ -2771,6 +2859,10 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: "Ensure SGLang backend is used with speculative decoding." ) + self.logger.info( + f"[DiagUW] About to enter first dist.barrier(cpu_group) [after MTP update] " + f"at elapsed={_diag_time.time() - _diag_t0:.3f}s" + ) dist.barrier(group=self.cpu_group) buffer_size = 0 @@ -2789,9 +2881,17 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) if named_tensors: - # This function will early return if not pipeline parallel head + self.logger.info( + f"[DiagUW] Calling _update_bucket_expert_weights_from_distributed " + f"with {len(named_tensors)} expert tensors at elapsed=" + f"{_diag_time.time() - _diag_t0:.3f}s" + ) self._update_bucket_expert_weights_from_distributed(meta, named_tensors) + self.logger.info( + f"[DiagUW] About to enter second dist.barrier(cpu_group) [after expert update] " + f"at elapsed={_diag_time.time() - _diag_t0:.3f}s" + ) dist.barrier(group=self.cpu_group) if dist.get_rank() == 0: @@ -2799,6 +2899,10 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: current_platform.synchronize() dist.barrier(group=self.cpu_group) + self.logger.info( + f"[DiagUW] _update_weights_from_distributed FULLY COMPLETED " + f"in {_diag_time.time() - _diag_t0:.3f}s" + ) @trace_perf("megatron_engine.update_weights_from_disk", category="io") def _update_weights_from_disk(self, meta: WeightUpdateMeta) -> None: diff --git a/areal/infra/controller/rollout_callback.py b/areal/infra/controller/rollout_callback.py index 1f9f3f959b..202a4c3d2a 100644 --- a/areal/infra/controller/rollout_callback.py +++ b/areal/infra/controller/rollout_callback.py @@ -252,9 +252,21 @@ def update_weights_from_tensor( "Raw tensor mode is not supported through the callback chain." ) _t0 = time.time() + _n_snt = ( + len(serialized_payload.get("serialized_named_tensors", [])) + if isinstance(serialized_payload, dict) + else 0 + ) + _snt_b64_len = ( + sum(len(s) for s in serialized_payload.get("serialized_named_tensors", [])) + if isinstance(serialized_payload, dict) + else 0 + ) logger.info( f"[DiagMTP][Callback] update_weights_from_tensor ENTERED. " f"serialized_payload keys={list(serialized_payload.keys())}, " + f"n_serialized_tensors={_n_snt}, " + f"total_b64_bytes={_snt_b64_len} ({_snt_b64_len / 1024 / 1024:.2f} MB), " f"controller_addr={self.controller_addr}" ) payload = { @@ -264,7 +276,8 @@ def update_weights_from_tensor( _t1 = time.time() logger.info( f"[DiagMTP][Callback] serialize_value took {_t1 - _t0:.3f}s, " - f"payload_approx_size={_payload_size} bytes" + f"payload_approx_size={_payload_size} bytes " + f"({_payload_size / 1024 / 1024:.2f} MB)" ) # Synchronous blocking POST: MTP tensor update must complete before diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 604d985f20..8aebfca34f 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -586,17 +586,16 @@ def update_weights(): self.update_weights_from_distributed(meta, param_specs), self._callback_loop, ) - # Track future so tensor update can drain pending NCCL work - # before dispatching its own RPC to the engine thread queue. with self._xccl_futures_lock: - # Prune completed futures to avoid unbounded growth self._pending_xccl_futures = [ f for f in self._pending_xccl_futures if not f.done() ] self._pending_xccl_futures.append(fut) + _n_pending = len(self._pending_xccl_futures) logger.info( f"[DiagMTP] /callback/update_weights_xccl returning HTTP 200 " - f"(fire-and-forget, handler took {time.time() - _t0:.3f}s)" + f"(fire-and-forget, handler took {time.time() - _t0:.3f}s, " + f"pending_xccl_futures={_n_pending})" ) return jsonify({"status": "ok"}) @@ -649,10 +648,26 @@ def update_weights_tensor(): ) serialized_payload = deserialize_value(payload.get("serialized_payload")) _t2 = time.time() + _sp_keys = ( + list(serialized_payload.keys()) + if isinstance(serialized_payload, dict) + else "N/A" + ) + _n_snt = ( + len(serialized_payload.get("serialized_named_tensors", [])) + if isinstance(serialized_payload, dict) + else 0 + ) + _snt_b64_len = ( + sum(len(s) for s in serialized_payload.get("serialized_named_tensors", [])) + if isinstance(serialized_payload, dict) + else 0 + ) logger.info( f"[DiagMTP] deserialize_value completed in {_t2 - _t1:.3f}s, " f"serialized_payload type={type(serialized_payload).__name__}, " - f"keys={list(serialized_payload.keys()) if isinstance(serialized_payload, dict) else 'N/A'}" + f"keys={_sp_keys}, n_serialized_tensors={_n_snt}, " + f"total_b64_bytes={_snt_b64_len} ({_snt_b64_len / 1024 / 1024:.2f} MB)" ) # BLOCKING: MTP tensor update must complete before returning. # Following verl/slime's fully-blocking weight update pattern. @@ -844,15 +859,17 @@ async def _generic_collective_rpc_async( ) try: results = await asyncio.gather(*tasks) + _elapsed = _time.time() - _t0 logger.info( f"[DiagMTP] _generic_collective_rpc_async COMPLETED: " - f"method={method} in {_time.time() - _t0:.3f}s" + f"method={method} in {_elapsed:.3f}s" ) return results except Exception as e: + _elapsed = _time.time() - _t0 logger.error( f"[DiagMTP] _generic_collective_rpc_async FAILED: " - f"method={method} after {_time.time() - _t0:.3f}s: " + f"method={method} after {_elapsed:.3f}s: " f"{type(e).__name__}: {e}", exc_info=True, ) @@ -1183,17 +1200,27 @@ async def update_weights_from_distributed( _t0 = _time.time() _n_specs = len(param_specs) if param_specs else 0 + _spec_names = [s.name for s in param_specs[:5]] if param_specs else [] logger.info( f"[DiagMTP] async update_weights_from_distributed ENTERED " - f"(n_specs={_n_specs}, version={getattr(meta, 'version', '?')})" - ) - await self._collective_rpc_async( - "update_weights_from_distributed", meta=meta, param_specs=param_specs - ) - logger.info( - f"[DiagMTP] async update_weights_from_distributed COMPLETED " - f"in {_time.time() - _t0:.3f}s" + f"(n_specs={_n_specs}, version={getattr(meta, 'version', '?')}, " + f"spec_names={_spec_names}...)" ) + try: + await self._collective_rpc_async( + "update_weights_from_distributed", meta=meta, param_specs=param_specs + ) + logger.info( + f"[DiagMTP] async update_weights_from_distributed COMPLETED " + f"in {_time.time() - _t0:.3f}s" + ) + except Exception as e: + logger.error( + f"[DiagMTP] async update_weights_from_distributed FAILED " + f"after {_time.time() - _t0:.3f}s: {type(e).__name__}: {e}", + exc_info=True, + ) + raise async def update_weights_from_disk(self, meta: WeightUpdateMeta): meta.clear_checkpoint_after_load = False @@ -1236,21 +1263,31 @@ async def update_weights_from_tensor(self, serialized_payload: dict) -> None: self._pending_xccl_futures.clear() if pending: + _drain_t0 = time.time() logger.info( f"[DiagMTP] Draining {len(pending)} pending NCCL futures " f"before tensor update..." ) - # Wait for all pending NCCL coroutines to complete. - # Use asyncio wrap to avoid blocking the event loop. done_count = 0 - for fut in pending: + for i, fut in enumerate(pending): + _fut_t0 = time.time() try: await asyncio.wrap_future(fut) done_count += 1 + logger.info( + f"[DiagMTP] Drained future {i + 1}/{len(pending)} " + f"in {time.time() - _fut_t0:.3f}s (done={done_count})" + ) except Exception as e: - logger.warning(f"[DiagMTP] Pending NCCL future raised: {e}") + logger.warning( + f"[DiagMTP] Pending NCCL future {i + 1}/{len(pending)} " + f"raised after {time.time() - _fut_t0:.3f}s: {e}" + ) done_count += 1 - logger.info(f"[DiagMTP] Drained {done_count}/{len(pending)} NCCL futures.") + logger.info( + f"[DiagMTP] Drained {done_count}/{len(pending)} NCCL futures " + f"in {time.time() - _drain_t0:.3f}s" + ) import time as _time diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index 09acf0dd96..7890be4226 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -1557,20 +1557,37 @@ def _update_weights_from_distributed( request_timeout: float, ): """Helper to update weights from distributed memory in a separate process.""" + import time as _diag_time + + _diag_t0 = _diag_time.time() + logger.info( + f"[DiagWorker] _update_weights_from_distributed ENTERED: " + f"n_specs={len(param_specs)}, n_addrs={len(addresses)}, " + f"addrs={addresses}, version={getattr(meta, 'version', '?')}" + ) async def _fn(): - # Get requests from backend + _fn_t0 = _diag_time.time() weight_reqs = backend.build_distributed_weight_update_requests( meta, param_specs ) + logger.info( + f"[DiagWorker] build_distributed_weight_update_requests completed " + f"in {_diag_time.time() - _fn_t0:.3f}s, " + f"n_requests={len(weight_reqs.requests)}" + ) - # Execute all requests sequentially (they may have dependencies) async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=request_timeout), read_bufsize=1024 * 1024 * 10, connector=get_default_connector(), ) as session: - for http_req in weight_reqs.requests: + for req_idx, http_req in enumerate(weight_reqs.requests): + _req_t0 = _diag_time.time() + logger.info( + f"[DiagWorker] Processing request {req_idx + 1}/{len(weight_reqs.requests)}: " + f"endpoint={http_req.endpoint}, method={http_req.method}" + ) jobs = [ arequest_with_retry( session=session, @@ -1584,5 +1601,18 @@ async def _fn(): for addr in addresses ] await asyncio.gather(*jobs) + logger.info( + f"[DiagWorker] Request {req_idx + 1}/{len(weight_reqs.requests)} " + f"completed in {_diag_time.time() - _req_t0:.3f}s" + ) - return uvloop.run(_fn()) + logger.info( + f"[DiagWorker] _fn() completed in {_diag_time.time() - _fn_t0:.3f}s" + ) + + result = uvloop.run(_fn()) + logger.info( + f"[DiagWorker] _update_weights_from_distributed COMPLETED " + f"in {_diag_time.time() - _diag_t0:.3f}s" + ) + return result From cfd9115f8a53d885d299ac244c634a3122df2ab3 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 20:35:19 +0800 Subject: [PATCH 057/140] fix(engine): fix CUDA stream --- areal/engine/megatron_engine.py | 96 +++++++++++++++------------------ 1 file changed, 44 insertions(+), 52 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 873fb130fb..69cc5333c1 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2401,7 +2401,6 @@ def _serialize_mtp_tensors_for_update( # ------------------------------------------------------------------- _t_sync = _time.time() - # Create a dedicated serialization stream free of NCCL deps _ser_stream = torch.cuda.Stream() _has_event = hasattr(self, "_mtp_data_ready_event") and self._mtp_data_ready_event is not None @@ -2454,7 +2453,7 @@ def _do_sync(): _sync_thread.start() _sync_wait_start = _time.time() - _sync_timeout = 30.0 + _sync_timeout = 60.0 while not _sync_done.wait(timeout=1.0): _waited = _time.time() - _sync_wait_start if _has_event: @@ -2476,9 +2475,9 @@ def _do_sync(): if _waited >= _sync_timeout: self.logger.error( f"[MTPSerialize] _ser_stream.synchronize() TIMEOUT after " - f"{_waited:.1f}s! This indicates a CUDA stream deadlock. " - f"The _mtp_data_ready_event is likely blocked by NCCL " - f"cross-stream dependencies on the default stream." + f"{_waited:.1f}s! CUDA stream deadlock detected. " + f"_mtp_data_ready_event was recorded at the wrong point " + f"(after NCCL handle.wait() polluted the default stream)." ) break @@ -2657,51 +2656,56 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: mtp_param_count = 0 mtp_param_bytes = 0 - # Collect MTP weights in HF format for draft model tensor update mtp_hf_tensors = [] _collect_mtp_for_draft = ( self.enable_mtp_training and getattr(self, "_engine_supports_tensor_update", False) and self.is_pipeline_parallel_head() ) + + if _collect_mtp_for_draft: + _t_mtp_pre = _diag_time.time() + for name, param in get_named_parameters(self.model, num_moe_experts): + if ".experts." in name: + continue + if ".mtp." not in name: + continue + mtp_param_count += 1 + mtp_param_bytes += param.numel() * param.element_size() + _mtp_param, _ = self._collect_param(name, param) + _mtp_model_name = self.hf_config.model_type + mtp_hf_tensors.extend( + convert_to_hf( + self.tf_config, + _mtp_model_name, + name, + _mtp_param, + quantization_config=self.quantization_config, + fp8_direct_convert=self.fp8_direct_convert, + ) + ) + if mtp_hf_tensors: + self._mtp_data_ready_event = torch.cuda.Event() + self._mtp_data_ready_event.record(torch.cuda.current_stream()) + self.logger.info( + f"[DiagUW] Recorded _mtp_data_ready_event on default stream " + f"(device={torch.cuda.current_device()}) BEFORE any NCCL broadcasts. " + f"n_mtp_tensors={len(mtp_hf_tensors)}, " + f"mtp_param_count={mtp_param_count}, " + f"mtp_param_bytes={mtp_param_bytes / 1024 / 1024:.2f} MB, " + f"elapsed={_diag_time.time() - _diag_t0:.3f}s, " + f"pre_loop_took={_diag_time.time() - _t_mtp_pre:.3f}s" + ) + else: + self.logger.info( + f"[DiagUW] No MTP tensors collected in pre-loop " + f"(mtp_param_count={mtp_param_count})" + ) + for name, param in get_named_parameters(self.model, num_moe_experts): if ".experts." in name: continue if ".mtp." in name: - mtp_param_count += 1 - mtp_param_bytes += param.numel() * param.element_size() - if _collect_mtp_for_draft: - # Collect and all-gather MTP param, then convert to HF - # format. _collect_param handles TP all-gather, padding - # removal, and FP8 dequant (same as NCCL path). - _mtp_param, _ = self._collect_param(name, param) - _mtp_model_name = self.hf_config.model_type - mtp_hf_tensors.extend( - convert_to_hf( - self.tf_config, - _mtp_model_name, - name, - _mtp_param, - quantization_config=self.quantization_config, - fp8_direct_convert=self.fp8_direct_convert, - ) - ) - # --------------------------------------------------------- - # MTP weights are synced to the EAGLE draft model via the - # separate /update_weights_from_tensor HTTP path (see below). - # SGLang's MiMoForCausalLM.load_weights() silently drops - # any name containing "mtp_layers", so broadcasting them - # via NCCL is redundant. Worse, the inference-side NCCL - # handler calls flush_cache() -> torch.cuda.empty_cache() - # which blocks on the inference GPU while training-side - # torch.cuda.synchronize() in _serialize_mtp_tensors waits - # for the same NCCL group -- creating a circular deadlock. - # - # By skipping NCCL for MTP params we: - # 1. Eliminate the deadlock between flush_cache and sync - # 2. Reduce unnecessary NCCL traffic (~402 MB saved) - # 3. Avoid engine-thread queue contention on inference side - # --------------------------------------------------------- continue if self.config.use_lora and ( ".adapter." not in name or not getattr(param, "requires_grad", False) @@ -2725,18 +2729,6 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"buffer_size={buffer_size}" ) - if _collect_mtp_for_draft and mtp_hf_tensors: - self._mtp_data_ready_event = torch.cuda.Event() - self._mtp_data_ready_event.record(torch.cuda.current_stream()) - _event_recorded_at = _diag_time.time() - self.logger.info( - f"[DiagUW] Recorded _mtp_data_ready_event on default stream " - f"(device={torch.cuda.current_device()}) BEFORE NCCL broadcasts. " - f"n_mtp_tensors={len(mtp_hf_tensors)}, " - f"elapsed={_event_recorded_at - _diag_t0:.3f}s, " - f"default_stream={torch.cuda.current_stream()}" - ) - if converted_named_tensors: self.logger.info( f"[DiagUW] Calling _update_bucket_weights_from_distributed with " From d373f0360c98a52069aab9080d4ab8f6eae6cd31 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 21:45:07 +0800 Subject: [PATCH 058/140] feat(megatron): add log --- areal/engine/megatron_engine.py | 48 +++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 69cc5333c1..5053c77227 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2194,12 +2194,29 @@ def _impl_update_weight_from_distributed( buffer_size: int, weight_chunked_mem_size: int, ) -> int: + import time as _diag_time + + _t0 = _diag_time.time() + self.logger.info( + f"[DiagImpl] Rank {dist.get_rank()} _collect_param START " + f"name={name}" + ) param, param_size = self._collect_param(name, param) + self.logger.info( + f"[DiagImpl] Rank {dist.get_rank()} _collect_param DONE " + f"name={name}, param_size={param_size / 1024 / 1024:.2f} MB, " + f"took={_diag_time.time() - _t0:.3f}s" + ) if not self.is_pipeline_parallel_head(): return buffer_size if buffer_size + param_size > weight_chunked_mem_size: + self.logger.info( + f"[DiagImpl] Buffer overflow ({buffer_size / 1024 / 1024:.2f} + " + f"{param_size / 1024 / 1024:.2f} > {weight_chunked_mem_size / 1024 / 1024:.2f} MB), " + f"flushing {len(converted_named_tensors)} tensors, name={name}" + ) self._update_bucket_weights_from_distributed(meta, converted_named_tensors) buffer_size = 0 @@ -2646,7 +2663,15 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if dist.get_rank() == 0: self.rollout_engine.pause_generation() + self.logger.info( + f"[DiagUW] Rank {dist.get_rank()} about to enter first cpu_group barrier " + f"at elapsed={_diag_time.time() - _diag_t0:.3f}s" + ) dist.barrier(group=self.cpu_group) + self.logger.info( + f"[DiagUW] Rank {dist.get_rank()} passed first cpu_group barrier " + f"at elapsed={_diag_time.time() - _diag_t0:.3f}s" + ) num_moe_experts = self.tf_config.num_moe_experts weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024 @@ -2672,7 +2697,16 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: continue mtp_param_count += 1 mtp_param_bytes += param.numel() * param.element_size() + self.logger.info( + f"[DiagUW] Pre-loop MTP param[{mtp_param_count}] " + f"name={name}, size={param.numel() * param.element_size() / 1024 / 1024:.2f} MB, " + f"calling _collect_param..." + ) _mtp_param, _ = self._collect_param(name, param) + self.logger.info( + f"[DiagUW] Pre-loop MTP param[{mtp_param_count}] " + f"_collect_param DONE, name={name}" + ) _mtp_model_name = self.hf_config.model_type mtp_hf_tensors.extend( convert_to_hf( @@ -2702,6 +2736,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"(mtp_param_count={mtp_param_count})" ) + _param_idx = 0 for name, param in get_named_parameters(self.model, num_moe_experts): if ".experts." in name: continue @@ -2711,6 +2746,13 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ".adapter." not in name or not getattr(param, "requires_grad", False) ): continue + if _param_idx < 5 or _param_idx % 50 == 0: + self.logger.info( + f"[DiagUW] Rank {dist.get_rank()} main_loop param[{_param_idx}] " + f"name={name}, size={param.numel() * param.element_size() / 1024 / 1024:.2f} MB, " + f"buffer_size={buffer_size / 1024 / 1024:.2f} MB, " + f"elapsed={_diag_time.time() - _diag_t0:.3f}s" + ) buffer_size = self._impl_update_weight_from_distributed( meta, name, @@ -2719,6 +2761,12 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: buffer_size, weight_chunked_mem_size, ) + if _param_idx < 5 or _param_idx % 50 == 0: + self.logger.info( + f"[DiagUW] Rank {dist.get_rank()} main_loop param[{_param_idx}] " + f"DONE, buffer_size={buffer_size / 1024 / 1024:.2f} MB" + ) + _param_idx += 1 self.logger.info( f"[DiagUW] Parameter loop completed in " From 8c070b7fd65e1bb33d0f2fb384369fd9fd28cc4e Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 22:02:26 +0800 Subject: [PATCH 059/140] fix(rollout_controller): add --- areal/infra/controller/rollout_controller.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 8aebfca34f..0ad3dab651 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -724,7 +724,11 @@ def rollout_complete(): @app.errorhandler(Exception) def handle_error(e): - logger.error(f"Callback handler error: {e}") + logger.error( + f"Callback handler error: {e} " + f"(url={request.url}, method={request.method}, " + f"path={request.path}, endpoint={request.endpoint})" + ) return jsonify({"error": str(e)}), 500 self._callback_port = find_free_ports(1)[0] From a4b48c7649a352a5264c0e47d84fb8602edc900c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 22:53:58 +0800 Subject: [PATCH 060/140] feat(megatron): fix --- areal/engine/megatron_engine.py | 12 ++++++++++++ areal/engine/megatron_utils/megatron.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 5053c77227..9007fe95e5 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2168,6 +2168,14 @@ def _collect_param( Returns: Tuple of (prepared_param, param_size_in_bytes) """ + _has_tmp = hasattr(param, "tensor_model_parallel") + _is_tmp = getattr(param, "tensor_model_parallel", False) if _has_tmp else False + _is_dup = name in self._duplicated_param_names if self._duplicated_param_names else False + self.logger.info( + f"[DiagImpl] Rank {dist.get_rank()} all_gather_param START " + f"name={name}, has_tmp={_has_tmp}, is_tmp={_is_tmp}, is_dup={_is_dup}, " + f"param_shape={tuple(param.shape)}, param_dtype={param.dtype}" + ) param = all_gather_param( name, param, @@ -2175,6 +2183,10 @@ def _collect_param( quantization_config=self.quantization_config, duplicated_param_names=self._duplicated_param_names, ) + self.logger.info( + f"[DiagImpl] Rank {dist.get_rank()} all_gather_param DONE " + f"name={name}, result_type={type(param).__name__}" + ) param = remove_padding(name, param, self.hf_config.vocab_size) if isinstance(param, FP8BlockwiseTensorHelper): diff --git a/areal/engine/megatron_utils/megatron.py b/areal/engine/megatron_utils/megatron.py index 5d08d41f24..62f8132be0 100644 --- a/areal/engine/megatron_utils/megatron.py +++ b/areal/engine/megatron_utils/megatron.py @@ -133,6 +133,17 @@ def all_gather_param( partition_dim = param.partition_dim partition_stride = param.partition_stride + import logging as _logging + + _logger = _logging.getLogger("AReaL") + _logger.info( + f"[DiagAllGather] dist.all_gather ENTERED: name={name}, " + f"tp_size={tp_size}, partition_dim={partition_dim}, " + f"partition_stride={partition_stride}, " + f"param_shape={tuple(param.data.shape)}, " + f"param_dtype={param.dtype}" + ) + # Handle FP8 tensors specially if param_is_fp8 and fp8_direct_convert: block_size = get_block_size_from_config(quantization_config) @@ -144,6 +155,10 @@ def all_gather_param( param = _all_gather_and_concat( param.data, tp_size, tp_group, partition_dim, partition_stride, name ) + _logger.info( + f"[DiagAllGather] dist.all_gather COMPLETED: name={name}, " + f"result_shape={tuple(param.shape)}" + ) return param From c802beebfc81b539fc9e08627f93e44e19fefd1c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 23:40:05 +0800 Subject: [PATCH 061/140] refactor(megatron_engine): improve --- areal/engine/megatron_engine.py | 79 +++++++++++++++------------------ 1 file changed, 37 insertions(+), 42 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 9007fe95e5..c951ef607b 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2699,61 +2699,44 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: and getattr(self, "_engine_supports_tensor_update", False) and self.is_pipeline_parallel_head() ) + _mtp_event_recorded = False - if _collect_mtp_for_draft: - _t_mtp_pre = _diag_time.time() - for name, param in get_named_parameters(self.model, num_moe_experts): - if ".experts." in name: - continue - if ".mtp." not in name: - continue + _param_idx = 0 + for name, param in get_named_parameters(self.model, num_moe_experts): + if ".experts." in name: + continue + if ".mtp." in name: mtp_param_count += 1 mtp_param_bytes += param.numel() * param.element_size() - self.logger.info( - f"[DiagUW] Pre-loop MTP param[{mtp_param_count}] " - f"name={name}, size={param.numel() * param.element_size() / 1024 / 1024:.2f} MB, " - f"calling _collect_param..." - ) - _mtp_param, _ = self._collect_param(name, param) - self.logger.info( - f"[DiagUW] Pre-loop MTP param[{mtp_param_count}] " - f"_collect_param DONE, name={name}" - ) - _mtp_model_name = self.hf_config.model_type - mtp_hf_tensors.extend( - convert_to_hf( - self.tf_config, - _mtp_model_name, - name, - _mtp_param, - quantization_config=self.quantization_config, - fp8_direct_convert=self.fp8_direct_convert, + if _collect_mtp_for_draft: + _mtp_param, _ = self._collect_param(name, param) + _mtp_model_name = self.hf_config.model_type + mtp_hf_tensors.extend( + convert_to_hf( + self.tf_config, + _mtp_model_name, + name, + _mtp_param, + quantization_config=self.quantization_config, + fp8_direct_convert=self.fp8_direct_convert, + ) ) - ) - if mtp_hf_tensors: + else: + self._collect_param(name, param) + continue + if not _mtp_event_recorded and _collect_mtp_for_draft and mtp_hf_tensors: self._mtp_data_ready_event = torch.cuda.Event() self._mtp_data_ready_event.record(torch.cuda.current_stream()) + _mtp_event_recorded = True self.logger.info( f"[DiagUW] Recorded _mtp_data_ready_event on default stream " - f"(device={torch.cuda.current_device()}) BEFORE any NCCL broadcasts. " + f"(device={torch.cuda.current_device()}) BEFORE first NCCL broadcast. " f"n_mtp_tensors={len(mtp_hf_tensors)}, " f"mtp_param_count={mtp_param_count}, " f"mtp_param_bytes={mtp_param_bytes / 1024 / 1024:.2f} MB, " f"elapsed={_diag_time.time() - _diag_t0:.3f}s, " - f"pre_loop_took={_diag_time.time() - _t_mtp_pre:.3f}s" + f"next_param={name}" ) - else: - self.logger.info( - f"[DiagUW] No MTP tensors collected in pre-loop " - f"(mtp_param_count={mtp_param_count})" - ) - - _param_idx = 0 - for name, param in get_named_parameters(self.model, num_moe_experts): - if ".experts." in name: - continue - if ".mtp." in name: - continue if self.config.use_lora and ( ".adapter." not in name or not getattr(param, "requires_grad", False) ): @@ -2780,6 +2763,18 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) _param_idx += 1 + if not _mtp_event_recorded and _collect_mtp_for_draft and mtp_hf_tensors: + self._mtp_data_ready_event = torch.cuda.Event() + self._mtp_data_ready_event.record(torch.cuda.current_stream()) + _mtp_event_recorded = True + self.logger.info( + f"[DiagUW] Recorded _mtp_data_ready_event on default stream " + f"(device={torch.cuda.current_device()}) after param loop (no NCCL broadcasts triggered). " + f"n_mtp_tensors={len(mtp_hf_tensors)}, " + f"mtp_param_count={mtp_param_count}, " + f"elapsed={_diag_time.time() - _diag_t0:.3f}s" + ) + self.logger.info( f"[DiagUW] Parameter loop completed in " f"{_diag_time.time() - _diag_t0:.3f}s. " From 3af690486dfd23aec7195f042e32b9453e88f112 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 22 Apr 2026 12:50:39 +0800 Subject: [PATCH 062/140] fix(mcore): deal eh_proj.weight --- areal/engine/megatron_engine.py | 161 ++++++++++++-------------------- areal/models/mcore/hf_load.py | 12 +++ 2 files changed, 73 insertions(+), 100 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index c951ef607b..546e48ca64 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2418,7 +2418,7 @@ def _serialize_mtp_tensors_for_update( ) # ------------------------------------------------------------------- - # GPU → CPU copy on a *dedicated CUDA stream* that is insulated + # GPU -> CPU copy on a *dedicated CUDA stream* that is insulated # from NCCL broadcast dependencies. # # Recorded _mtp_data_ready_event on the default stream @@ -2430,114 +2430,65 @@ def _serialize_mtp_tensors_for_update( # ------------------------------------------------------------------- _t_sync = _time.time() + # Create a dedicated serialization stream free of NCCL deps _ser_stream = torch.cuda.Stream() _has_event = hasattr(self, "_mtp_data_ready_event") and self._mtp_data_ready_event is not None if _has_event: - _event_query_before = self._mtp_data_ready_event.query() - self.logger.info( - f"[MTPSerialize] _mtp_data_ready_event status BEFORE wait_event: " - f"query()={_event_query_before} (True=signaled, False=pending), " - f"event={self._mtp_data_ready_event}" - ) + _evt_query = self._mtp_data_ready_event.query() + # Make ser_stream wait for MTP data (all_gather) but NOT NCCL broadcasts _ser_stream.wait_event(self._mtp_data_ready_event) self.logger.info( "[MTPSerialize] Created serialization stream and synced with " - "_mtp_data_ready_event (pre-NCCL). " + f"_mtp_data_ready_event (pre-NCCL). event_query={_evt_query}, " f"(device={torch.cuda.current_device()}, " f"default_stream={torch.cuda.current_stream()}, " f"ser_stream={_ser_stream})" ) else: + # Fallback: no event recorded (shouldn't happen, but be safe). + # Wait on the default stream which may include NCCL deps. _ser_stream.wait_stream(torch.cuda.current_stream()) self.logger.warning( "[MTPSerialize] _mtp_data_ready_event NOT found! " - "Falling back to wait_stream(current_stream) — " + "Falling back to wait_stream(current_stream) -- " "this may block on NCCL. " f"(device={torch.cuda.current_device()})" ) + # Synchronize the serialization stream -- this should be fast + # since it only waits on the pre-NCCL event, not NCCL broadcasts. self.logger.info( - f"[MTPSerialize] About to call _ser_stream.synchronize()... " - f"(ser_stream={_ser_stream}, " - f"cuda_device={torch.cuda.current_device()}, " - f"mem_alloc={torch.cuda.memory_allocated() / 1024 / 1024:.0f} MB, " - f"mem_reserved={torch.cuda.memory_reserved() / 1024 / 1024:.0f} MB)" + "[MTPSerialize] About to _ser_stream.synchronize() ..." ) - - import threading - - _sync_done = threading.Event() - _sync_exc = [None] - - def _do_sync(): - try: - _ser_stream.synchronize() - except Exception as exc: - _sync_exc[0] = exc - finally: - _sync_done.set() - - _sync_thread = threading.Thread(target=_do_sync, daemon=True) - _sync_thread.start() - - _sync_wait_start = _time.time() _sync_timeout = 60.0 - while not _sync_done.wait(timeout=1.0): - _waited = _time.time() - _sync_wait_start - if _has_event: - _eq = self._mtp_data_ready_event.query() - self.logger.warning( - f"[MTPSerialize] _ser_stream.synchronize() STILL WAITING " - f"after {_waited:.1f}s (timeout={_sync_timeout}s)! " - f"event_query={_eq}, " - f"mem_alloc={torch.cuda.memory_allocated() / 1024 / 1024:.0f} MB, " - f"mem_reserved={torch.cuda.memory_reserved() / 1024 / 1024:.0f} MB" - ) - else: - self.logger.warning( - f"[MTPSerialize] _ser_stream.synchronize() STILL WAITING " - f"after {_waited:.1f}s (timeout={_sync_timeout}s)! " - f"mem_alloc={torch.cuda.memory_allocated() / 1024 / 1024:.0f} MB, " - f"mem_reserved={torch.cuda.memory_reserved() / 1024 / 1024:.0f} MB" - ) - if _waited >= _sync_timeout: - self.logger.error( - f"[MTPSerialize] _ser_stream.synchronize() TIMEOUT after " - f"{_waited:.1f}s! CUDA stream deadlock detected. " - f"_mtp_data_ready_event was recorded at the wrong point " - f"(after NCCL handle.wait() polluted the default stream)." - ) - break - - if _sync_exc[0] is not None: - raise _sync_exc[0] - + _sync_warn_interval = 1.0 + _sync_start = _time.time() + _warned = False + while True: + _ser_stream.synchronize() + break + _sync_elapsed = _time.time() - _sync_start if _has_event: - _event_query_after = self._mtp_data_ready_event.query() - self.logger.info( - f"[MTPSerialize] _mtp_data_ready_event status AFTER synchronize: " - f"query()={_event_query_after}" - ) + _evt_query_after = self._mtp_data_ready_event.query() + else: + _evt_query_after = "N/A" self.logger.info( f"[MTPSerialize] Serialization stream synced in " - f"{_time.time() - _t_sync:.3f}s" + f"{_sync_elapsed:.3f}s, event_query={_evt_query_after}" ) - # Reclaim Python-side references before GPU→CPU copies. + # Reclaim Python-side references before GPU->CPU copies. # We skip torch.cuda.empty_cache() here because it can trigger # an implicit device-wide sync (cudaDeviceSynchronize) which # would re-introduce the NCCL deadlock under near-OOM conditions. - # Instead, gc.collect() alone frees Python-side tensor refs, - # and the CUDA allocator will reuse freed blocks lazily. import gc _t_cache = _time.time() gc.collect() - _before = torch.cuda.memory_reserved() - _after = torch.cuda.memory_reserved() + _mem_reserved = torch.cuda.memory_reserved() self.logger.info( f"[MTPSerialize] gc.collect() completed " - f"(reserved={_before / 1024 / 1024:.0f} MB, " + f"(reserved={_mem_reserved / 1024 / 1024:.0f} MB, " f"no empty_cache to avoid device-wide sync), " f"took {_time.time() - _t_cache:.3f}s" ) @@ -2699,7 +2650,6 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: and getattr(self, "_engine_supports_tensor_update", False) and self.is_pipeline_parallel_head() ) - _mtp_event_recorded = False _param_idx = 0 for name, param in get_named_parameters(self.model, num_moe_experts): @@ -2711,6 +2661,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if _collect_mtp_for_draft: _mtp_param, _ = self._collect_param(name, param) _mtp_model_name = self.hf_config.model_type + _prev_count = len(mtp_hf_tensors) mtp_hf_tensors.extend( convert_to_hf( self.tf_config, @@ -2721,22 +2672,23 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: fp8_direct_convert=self.fp8_direct_convert, ) ) + # Diagnostic: log each converted MTP tensor with value + # statistics for post-mortem debugging of weight corruption. + for _hf_name, _hf_tensor in mtp_hf_tensors[_prev_count:]: + _abs = _hf_tensor.float().abs() + self.logger.info( + f"[MTPWeightDiag] convert_to_hf: " + f"megatron={name} -> hf={_hf_name}, " + f"shape={tuple(_hf_tensor.shape)}, " + f"dtype={_hf_tensor.dtype}, " + f"mean={_hf_tensor.float().mean().item():.6e}, " + f"abs_mean={_abs.mean().item():.6e}, " + f"abs_max={_abs.max().item():.6e}, " + f"norm={_hf_tensor.float().norm().item():.6e}" + ) else: self._collect_param(name, param) continue - if not _mtp_event_recorded and _collect_mtp_for_draft and mtp_hf_tensors: - self._mtp_data_ready_event = torch.cuda.Event() - self._mtp_data_ready_event.record(torch.cuda.current_stream()) - _mtp_event_recorded = True - self.logger.info( - f"[DiagUW] Recorded _mtp_data_ready_event on default stream " - f"(device={torch.cuda.current_device()}) BEFORE first NCCL broadcast. " - f"n_mtp_tensors={len(mtp_hf_tensors)}, " - f"mtp_param_count={mtp_param_count}, " - f"mtp_param_bytes={mtp_param_bytes / 1024 / 1024:.2f} MB, " - f"elapsed={_diag_time.time() - _diag_t0:.3f}s, " - f"next_param={name}" - ) if self.config.use_lora and ( ".adapter." not in name or not getattr(param, "requires_grad", False) ): @@ -2763,18 +2715,6 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) _param_idx += 1 - if not _mtp_event_recorded and _collect_mtp_for_draft and mtp_hf_tensors: - self._mtp_data_ready_event = torch.cuda.Event() - self._mtp_data_ready_event.record(torch.cuda.current_stream()) - _mtp_event_recorded = True - self.logger.info( - f"[DiagUW] Recorded _mtp_data_ready_event on default stream " - f"(device={torch.cuda.current_device()}) after param loop (no NCCL broadcasts triggered). " - f"n_mtp_tensors={len(mtp_hf_tensors)}, " - f"mtp_param_count={mtp_param_count}, " - f"elapsed={_diag_time.time() - _diag_t0:.3f}s" - ) - self.logger.info( f"[DiagUW] Parameter loop completed in " f"{_diag_time.time() - _diag_t0:.3f}s. " @@ -2784,6 +2724,27 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"buffer_size={buffer_size}" ) + # Record a CUDA event on the default stream BEFORE any NCCL + # broadcasts begin. At this point, all MTP tensors from + # _collect_param()'s synchronous dist.all_gather() are fully + # materialised on the default stream. We will use this event + # in _serialize_mtp_tensors_for_update() to create a separate + # CUDA stream that depends ONLY on work up to this point -- + # crucially, NOT on the NCCL broadcast operations that follow. + if _collect_mtp_for_draft and mtp_hf_tensors: + self._mtp_data_ready_event = torch.cuda.Event() + self._mtp_data_ready_event.record(torch.cuda.current_stream()) + _mtp_bytes_total = sum( + t.numel() * t.element_size() for _, t in mtp_hf_tensors + ) + self.logger.info( + f"[DiagUW] Recorded _mtp_data_ready_event on default stream " + f"(device={torch.cuda.current_device()}) BEFORE first NCCL " + f"broadcast. n_mtp_tensors={len(mtp_hf_tensors)}, " + f"mtp_param_count={mtp_param_count}, " + f"mtp_param_bytes={mtp_param_bytes / 1024 / 1024:.2f} MB" + ) + if converted_named_tensors: self.logger.info( f"[DiagUW] Calling _update_bucket_weights_from_distributed with " diff --git a/areal/models/mcore/hf_load.py b/areal/models/mcore/hf_load.py index 9e17799a85..e8b76fbaab 100644 --- a/areal/models/mcore/hf_load.py +++ b/areal/models/mcore/hf_load.py @@ -212,6 +212,18 @@ def _weight_to_mcore_tp( res = _merge_gate_up_weights(hf_weights_safe_slice, tp_rank, tp_size) elif "mlp.experts.linear_fc2.weight" in mcore_weights_name: res = _slice_moe_expert_weight(hf_weights_safe_slice, tp_rank, tp_size) + elif mcore_weights_name.endswith("eh_proj.weight"): + res = _slice_generic_weight( + mcore_param_shape, hf_weights_safe_slice, tp_rank, tp_size + ) + if not isinstance(res, FP8BlockwiseTensorHelper): + first_half, second_half = res.chunk(2, dim=1) + res = torch.cat([second_half, first_half], dim=1) + logger.info( + f"[MTPLoad] eh_proj.weight column-half swap applied: " + f"{mcore_weights_name}, shape={tuple(res.shape)}, " + f"tp_rank={tp_rank}, tp_size={tp_size}" + ) else: res = _slice_generic_weight( mcore_param_shape, hf_weights_safe_slice, tp_rank, tp_size From 4b2e96a624e129b4b15594e480266f5520282044 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 22 Apr 2026 20:17:52 +0800 Subject: [PATCH 063/140] fix(megatron_engine): remove code --- areal/engine/megatron_engine.py | 101 +++++++++++++++++++------------- 1 file changed, 59 insertions(+), 42 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 546e48ca64..6045319eac 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): self.logger.info( "[MTPTrain] Verified megatron-core MTP module available. " "Gradient isolation is handled by AReaL monkey-patches: " - "_MTPGradIsolator (backbone), functional_call (lm_head), " + "MTPLossAutoScaler passthrough (backbone), functional_call (lm_head), " "decoder_input.detach (embedding) when mtp_detach_heads=True." ) except ImportError: @@ -845,6 +845,9 @@ def _collect_mtp_loss(self) -> dict[str, float]: # to verify gradient isolation is working correctly. if is_last_pp_stage and self.mtp_detach_heads: try: + from megatron.core.transformer.multi_token_prediction import ( + MTPLossAutoScaler, + ) mtp_g = 0.0 non_mtp_g = 0.0 mtp_n = 0 @@ -853,20 +856,43 @@ def _collect_mtp_loss(self) -> dict[str, float]: lmh_g = 0.0 total_params = 0 no_grad_params = 0 + # Per-MTP-param diagnostics for debugging + mtp_param_details = [] for module in self.model: for name, param in module.named_parameters(): total_params += 1 - # Megatron DDP: main_grad > grad - grad = getattr(param, "main_grad", None) - if grad is None: + # Megatron DDP stores grads in main_grad + has_main_grad = hasattr(param, "main_grad") and param.main_grad is not None + has_grad = param.grad is not None + grad = None + grad_source = "none" + if has_main_grad: + grad = param.main_grad + grad_source = "main_grad" + elif has_grad: grad = param.grad + grad_source = "grad" if grad is None: no_grad_params += 1 + if ".mtp." in name: + mtp_param_details.append( + f" {name}: NO GRAD (main_grad={has_main_grad}, grad={has_grad})" + ) continue g = grad.data.float().norm() ** 2 if ".mtp." in name: mtp_g += g.item() mtp_n += 1 + # Log per-param detail for MTP params + g_norm = g.item() ** 0.5 + mtp_param_details.append( + f" {name}: norm={g_norm:.8f} src={grad_source}" + ) + # Also check if param.grad has gradient + # when main_grad is zero (diagnostic) + if g_norm == 0.0 and has_main_grad and has_grad: + alt_g = param.grad.data.float().norm().item() + mtp_param_details[-1] += f" ALT_grad_norm={alt_g:.8f}" else: non_mtp_g += g.item() non_mtp_n += 1 @@ -874,13 +900,31 @@ def _collect_mtp_loss(self) -> dict[str, float]: emb_g += g.item() if "output_layer" in name and ".mtp." not in name: lmh_g += g.item() + + # Log MTPLossAutoScaler backward scale for debugging + try: + scale_val = MTPLossAutoScaler.main_loss_backward_scale + if hasattr(scale_val, "item"): + scale_str = f"{scale_val.item():.6f}" + else: + scale_str = str(scale_val) + except Exception: + scale_str = "N/A" + self.logger.info( - f"[MTPDetach] Gradient norms (main_grad): " + f"[MTPDetach] Gradient norms: " f"mtp={mtp_g**0.5:.6f}({mtp_n} params), " f"non_mtp={non_mtp_g**0.5:.6f}({non_mtp_n} params), " f"emb={emb_g**0.5:.6f}, lmh={lmh_g**0.5:.6f}, " - f"total={total_params}, no_grad={no_grad_params}" + f"total={total_params}, no_grad={no_grad_params}, " + f"mtp_backward_scale={scale_str}" ) + # Log per-MTP-param details + if mtp_param_details: + self.logger.info( + "[MTPGradDiag] Per-MTP-param gradient norms:\n" + + "\n".join(mtp_param_details) + ) mtp_stats["mtp_grad_norm"] = mtp_g**0.5 mtp_stats["non_mtp_grad_norm"] = non_mtp_g**0.5 except Exception as e: @@ -1083,11 +1127,11 @@ def forward_step(batch_iter, model): # In Megatron-Core 0.16.0, MTP CE loss gradient leaks to # backbone through 3 paths: # - # Path 1: MTP loss → MTPLossAutoScaler → hidden_states → backbone - # MTPLossAutoScaler.apply(hidden_states, mtp_loss) attaches - # mtp_loss gradient to main model's hidden_states. - # Fix: Monkey-patch _postprocess with _MTPGradIsolator. - # + # Path 1: MTP loss → hidden_states → backbone + # ANALYSIS: MTPLossAutoScaler.backward() returns + # (grad_output, ones*scale) — grad_output is the main + # loss gradient (NOT mtp gradient). No leak here. + # Verified by verl's implementation which has no isolator. # Path 2: MTP loss → output_layer (lm_head) weights # MTP logits use the SHARED output_layer and output_weight. # MTP CE loss backpropagates through lm_head weights. @@ -1103,27 +1147,6 @@ def forward_step(batch_iter, model): if self.mtp_detach_heads: _orig_postprocess = _unwrapped._postprocess.__func__ - class _MTPGradIsolator(torch.autograd.Function): - """Gradient isolator for MTP loss (Path 1). - - Bridges original hidden_states with MTP-wrapped - hidden_states to prevent MTP CE gradient from - flowing through MTPLossAutoScaler → backbone. - - MTP params still get gradients because - MTPLossAutoScaler.backward() sends - ones_like(mtp_loss) * scale to mtp_loss regardless - of grad_output. - """ - - @staticmethod - def forward(ctx, original_hs, mtp_wrapped_hs): - return original_hs.clone() - - @staticmethod - def backward(ctx, grad_output): - return grad_output, torch.zeros_like(grad_output) - def _patched_postprocess( self_model, hidden_states, @@ -1144,11 +1167,11 @@ def _patched_postprocess( extra_block_kwargs=None, inference_context=None, _orig_fn=_orig_postprocess, - _isolator=_MTPGradIsolator, _logger=self.logger, ): """Patched _postprocess with comprehensive MTP - gradient isolation (Paths 1, 2, 3). + gradient isolation (Paths 2, 3). Path 1 removed + (MTPLossAutoScaler does not leak MTP grad to backbone). """ from megatron.core.transformer.multi_token_prediction import ( MTPLossAutoScaler, @@ -1197,8 +1220,6 @@ def _patched_postprocess( 1 + self_model.config.mtp_num_layers, dim=0, ) - # Path 1: save original hidden_states - _original_hs = hidden_states_list[0] hidden_states = hidden_states_list[0] if loss_mask is None: loss_mask = torch.ones_like(mtp_labels) @@ -1280,13 +1301,9 @@ def _patched_postprocess( mtp_loss_scale * mtp_loss / num_tokens, ) - # Path 1: apply gradient isolator - hidden_states = _isolator.apply( - _original_hs, hidden_states - ) _logger.debug( "[MTPDetach] Applied gradient isolation " - "in _postprocess (Paths 1+2)" + "in _postprocess (Path 2: detached output_layer)" ) # Inference last-token optimization @@ -1401,7 +1418,7 @@ def _patched_get_embeddings( self.logger.info( "[MTPDetach] Comprehensive MTP gradient isolation " f"enabled (mtp_detach_heads={self.mtp_detach_heads}): " - "Path 1 (_MTPGradIsolator for backbone hidden_states), " + "Path 1 (removed: MTPLossAutoScaler passthrough is safe), " "Path 2 (detached output_weight + functional_call for lm_head), " "Path 3 (detached decoder_input + hidden_states for embedding). " "MTP CE loss gradients will NOT flow through backbone, " From 57061fe98b4b3ceb0c9ed3b1091ca5ccb5de8b5a Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 22 Apr 2026 23:38:03 +0800 Subject: [PATCH 064/140] fix(megatron_engine): grad --- areal/engine/megatron_engine.py | 100 ++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 36 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 6045319eac..83c41dd29d 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): self.logger.info( "[MTPTrain] Verified megatron-core MTP module available. " "Gradient isolation is handled by AReaL monkey-patches: " - "MTPLossAutoScaler passthrough (backbone), functional_call (lm_head), " + "MTPLossAutoScaler passthrough (backbone), direct output_layer call (lm_head), " "decoder_input.detach (embedding) when mtp_detach_heads=True." ) except ImportError: @@ -925,6 +925,34 @@ def _collect_mtp_loss(self) -> dict[str, float]: "[MTPGradDiag] Per-MTP-param gradient norms:\n" + "\n".join(mtp_param_details) ) + # Additional diagnostic: check if any MTP param + # has .grad (not main_grad) with nonzero value, + # which would indicate gradient accumulation fusion + # mismatch between .grad and .main_grad + if mtp_g == 0.0: + alt_grad_found = False + for module in self.model: + for name, param in module.named_parameters(): + if ".mtp." not in name: + continue + if param.grad is not None and param.grad.data.float().norm().item() > 0: + alt_grad_found = True + self.logger.warning( + f"[MTPGradDiag] ALERT: {name} has nonzero .grad " + f"(norm={param.grad.data.float().norm().item():.8f}) " + f"but zero .main_grad! This indicates gradient " + f"accumulation fusion mismatch." + ) + if not alt_grad_found: + self.logger.warning( + "[MTPGradDiag] All MTP params have zero gradient " + "in BOTH .main_grad and .grad. The MTP backward " + "path is completely broken. Check: " + "1) MTPLossAutoScaler.backward is being called, " + "2) mtp_loss requires_grad=True, " + "3) _mtp_hs requires_grad=True, " + "4) activation checkpointing compatibility." + ) mtp_stats["mtp_grad_norm"] = mtp_g**0.5 mtp_stats["non_mtp_grad_norm"] = non_mtp_g**0.5 except Exception as e: @@ -1136,7 +1164,7 @@ def forward_step(batch_iter, model): # MTP logits use the SHARED output_layer and output_weight. # MTP CE loss backpropagates through lm_head weights. # Fix: Detach output_weight in _postprocess MTP loop, and - # use functional_call with detached params for output_layer. + # use direct output_layer call. # # Path 3: MTP loss → embedding weights # MTP layers call embedding(input_ids, position_ids) using @@ -1224,37 +1252,32 @@ def _patched_postprocess( if loss_mask is None: loss_mask = torch.ones_like(mtp_labels) - # Path 2: detach output weight for MTP - _mtp_output_weight = ( - output_weight.detach() - if output_weight is not None - else None - ) - for mtp_layer_number in range( self_model.config.mtp_num_layers ): - # Path 2: functional_call with detached - # output_layer params for MTP logits + # Use direct output_layer call for MTP logits + # Previous functional_call + detached params + # broke the backward gradient chain, causing + # mtp_grad_norm=0. The direct call allows + # MTP loss gradient to also accumulate on + # output_layer weights — this is acceptable + # as MTP loss is small (scaled by + # mtp_loss_scaling_factor) and matches + # Megatron-Core's native implementation. _mtp_hs = hidden_states_list[mtp_layer_number + 1] - _ol = self_model.output_layer - _ol_params = { - k: v.detach() for k, v in _ol.named_parameters() - } - _ol_buffers = dict(_ol.named_buffers()) - _ol_kwargs = { - "weight": _mtp_output_weight, - "runtime_gather_output": ( - runtime_gather_output - ), - } - mtp_logits, _ = torch.func.functional_call( - _ol, - {**_ol_params, **_ol_buffers}, - (_mtp_hs,), - _ol_kwargs, + mtp_logits, _ = self_model.output_layer( + _mtp_hs, + weight=output_weight, + runtime_gather_output=runtime_gather_output, ) - + # Diagnostic: verify gradient chain is intact + if self_model.training and _logger.isEnabledFor(10): + _logger.debug( + f"[MTPFwdDiag] _mtp_hs.requires_grad={_mtp_hs.requires_grad}, " + f"_mtp_hs.grad_fn={type(_mtp_hs.grad_fn).__name__ if _mtp_hs.grad_fn else 'None'}, " + f"mtp_logits.requires_grad={mtp_logits.requires_grad}, " + f"mtp_logits.grad_fn={type(mtp_logits.grad_fn).__name__ if mtp_logits.grad_fn else 'None'}" + ) mtp_labels, _ = roll_tensor( mtp_labels, shifts=-1, @@ -1273,6 +1296,12 @@ def _patched_postprocess( mtp_labels, mtp_logits ) mtp_loss = loss_mask * mtp_loss + if self_model.training and _logger.isEnabledFor(10): + _logger.debug( + f"[MTPFwdDiag] mtp_loss.requires_grad={mtp_loss.requires_grad}, " + f"mtp_loss.grad_fn={type(mtp_loss.grad_fn).__name__ if mtp_loss.grad_fn else 'None'}, " + f"mtp_loss_sum={mtp_loss.sum().item():.6f}" + ) if self_model.training: from megatron.core import ( parallel_state, @@ -1302,8 +1331,7 @@ def _patched_postprocess( ) _logger.debug( - "[MTPDetach] Applied gradient isolation " - "in _postprocess (Path 2: detached output_layer)" + "[MTPDetach] MTP loss computed via direct output_layer" ) # Inference last-token optimization @@ -1416,13 +1444,13 @@ def _patched_get_embeddings( if random.random() < 0.001: self.logger.info( - "[MTPDetach] Comprehensive MTP gradient isolation " - f"enabled (mtp_detach_heads={self.mtp_detach_heads}): " - "Path 1 (removed: MTPLossAutoScaler passthrough is safe), " - "Path 2 (detached output_weight + functional_call for lm_head), " + "[MTPDetach] MTP gradient isolation enabled " + f"(mtp_detach_heads={self.mtp_detach_heads}): " + "Path 2 (direct output_layer call for MTP logits, " + "matching verl/Megatron-Core approach), " "Path 3 (detached decoder_input + hidden_states for embedding). " - "MTP CE loss gradients will NOT flow through backbone, " - "lm_head, or embedding parameters." + "MTP CE loss gradients will update MTP params and " + "output_layer, but NOT backbone or embedding parameters." ) else: self.logger.info( From fa88152c1e270fad70bce1af6073ffe3b22566b4 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 23 Apr 2026 10:09:11 +0800 Subject: [PATCH 065/140] feat(megatron_engine): add mtp log --- areal/engine/megatron_engine.py | 85 +++++++++++++++++++++++++++++---- 1 file changed, 77 insertions(+), 8 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 83c41dd29d..d0f76ab9ac 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -953,8 +953,27 @@ def _collect_mtp_loss(self) -> dict[str, float]: "3) _mtp_hs requires_grad=True, " "4) activation checkpointing compatibility." ) + self.logger.info( + "[MTPGradDiag] Deep chain check: examining " + "MTP param registration in grad buffer...") + for module in self.model: + for name, param in module.named_parameters(): + if ".mtp." not in name: + continue + _has_mg = hasattr(param, "main_grad") and param.main_grad is not None + _has_g = param.grad is not None + _flag_v = getattr(param, "grad_added_to_main_grad", "N/A") + _mg_ptr = param.main_grad.data_ptr() if _has_mg else 0 + self.logger.info( + "[MTPGradDiag] %s: " + "has_main_grad=%s, has_grad=%s, " + "grad_added_flag=%s, rg=%s, mg_ptr=%s", + name, _has_mg, _has_g, + _flag_v, param.requires_grad, _mg_ptr) mtp_stats["mtp_grad_norm"] = mtp_g**0.5 mtp_stats["non_mtp_grad_norm"] = non_mtp_g**0.5 + mtp_stats["mtp_backward_scale"] = ( + float(scale_str) if scale_str != "N/A" else 0.0) except Exception as e: self.logger.warning( f"[MTPDetach] Grad norm logging failed: {e}" @@ -1172,6 +1191,8 @@ def forward_step(batch_iter, model): # decoder_input back to embedding weights. # Fix: Patch _get_embeddings to detach decoder_input. # ----------------------------------------------------------- + _mtp_diag_mb_counter = [0] + if self.mtp_detach_heads: _orig_postprocess = _unwrapped._postprocess.__func__ @@ -1265,6 +1286,15 @@ def _patched_postprocess( # mtp_loss_scaling_factor) and matches # Megatron-Core's native implementation. _mtp_hs = hidden_states_list[mtp_layer_number + 1] + if _mtp_diag_mb_counter[0] == 0: + _mtp_hs_gfn = type(_mtp_hs.grad_fn).__name__ if _mtp_hs.grad_fn else "None" + _logger.info( + "[MTPFwdDiag] MB#0 Layer#%d: " + "_mtp_hs.rg=%s, shape=%s, grad_fn=%s, " + "hs.rg=%s", + mtp_layer_number, _mtp_hs.requires_grad, + list(_mtp_hs.shape), _mtp_hs_gfn, + hidden_states.requires_grad) mtp_logits, _ = self_model.output_layer( _mtp_hs, weight=output_weight, @@ -1296,12 +1326,17 @@ def _patched_postprocess( mtp_labels, mtp_logits ) mtp_loss = loss_mask * mtp_loss - if self_model.training and _logger.isEnabledFor(10): + if _mtp_diag_mb_counter[0] == 0: + _ml_gfn = type(mtp_loss.grad_fn).__name__ if mtp_loss.grad_fn else "None" + _logger.info( + "[MTPFwdDiag] MB#0 mtp_loss: " + "rg=%s, grad_fn=%s, sum=%.6f, num_tokens=%s", + mtp_loss.requires_grad, _ml_gfn, + mtp_loss.sum().item(), num_tokens) + elif self_model.training and _logger.isEnabledFor(10): _logger.debug( - f"[MTPFwdDiag] mtp_loss.requires_grad={mtp_loss.requires_grad}, " - f"mtp_loss.grad_fn={type(mtp_loss.grad_fn).__name__ if mtp_loss.grad_fn else 'None'}, " - f"mtp_loss_sum={mtp_loss.sum().item():.6f}" - ) + "[MTPFwdDiag] mtp_loss.rg=%s, sum=%.6f", + mtp_loss.requires_grad, mtp_loss.sum().item()) if self_model.training: from megatron.core import ( parallel_state, @@ -1330,9 +1365,27 @@ def _patched_postprocess( mtp_loss_scale * mtp_loss / num_tokens, ) - _logger.debug( - "[MTPDetach] MTP loss computed via direct output_layer" - ) + _logger.info( + "[MTPDetach] MTP loss computed via direct output_layer call") + + if (_mtp_diag_mb_counter[0] == 0 + and hidden_states.requires_grad): + def _mtp_backward_hook(grad, _lg=_logger): + _lg.info( + "[MTPBwdDiag] AutoScaler backward FIRED: " + "grad.shape=%s, grad.norm=%.8f, " + "grad.abs_max=%.8f", + list(grad.shape), + grad.float().norm().item(), + grad.float().abs().max().item()) + hidden_states.register_hook(_mtp_backward_hook) + _logger.info( + "[MTPFwdDiag] MB#0 Registered backward hook on " + "hidden_states(post-AutoScaler): shape=%s, rg=%s", + list(hidden_states.shape), + hidden_states.requires_grad) + + _mtp_diag_mb_counter[0] += 1 # Inference last-token optimization sequence_parallel_override = False @@ -1431,6 +1484,16 @@ def _patched_get_embeddings( requires_grad=True, keep_graph=False, ) + if not hasattr(_patched_get_embeddings, "_diag_done"): + _patched_get_embeddings._diag_done = True + import logging as _log_m + _ge_lg = _log_m.getLogger("MegatronEngine") + _hs_gfn = type(_hs.grad_fn).__name__ if _hs.grad_fn else "None" + _ge_lg.info( + "[MTPEmbDiag] _patched_get_embeddings: " + "_dec_input.rg=%s, _hs.rg=%s, _hs.grad_fn=%s", + _dec_input.requires_grad, + _hs.requires_grad, _hs_gfn) return _ids, _pos, _dec_input, _hs _layer._get_embeddings = _patched_get_embeddings @@ -1472,6 +1535,12 @@ def _mtp_loss_fn( _rem=_remaining, _orig=_orig_clm, ): + if _mtp_diag_mb_counter[0] <= 2: + _logger.info( + "[MTPLossFnDiag] _mtp_loss_fn called: " + "_rem=%d, _logits.rg=%s, shape=%s", + _rem[0], _logits.requires_grad, + list(_logits.shape)) if _rem[0] > 0: _rem[0] -= 1 return _orig(_labels, _logits) From 02dc3268e38e863d4b824e833462ae93f3d92d25 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 23 Apr 2026 10:55:27 +0800 Subject: [PATCH 066/140] fix: use _logger --- areal/engine/megatron_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index d0f76ab9ac..c79dff87e5 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1534,9 +1534,10 @@ def _mtp_loss_fn( _logits, _rem=_remaining, _orig=_orig_clm, + _lg=self.logger, ): if _mtp_diag_mb_counter[0] <= 2: - _logger.info( + _lg.info( "[MTPLossFnDiag] _mtp_loss_fn called: " "_rem=%d, _logits.rg=%s, shape=%s", _rem[0], _logits.requires_grad, From f8c2dab30e5c84301d4b0e98551356886864b860 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 24 Apr 2026 12:11:00 +0800 Subject: [PATCH 067/140] fix(engine): fix mtp gradient --- areal/engine/megatron_engine.py | 105 ++++++++++++++++++++++++-------- 1 file changed, 81 insertions(+), 24 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index c79dff87e5..4049b44feb 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -786,9 +786,10 @@ def _roll_tensor_packed( def _collect_mtp_loss(self) -> dict[str, float]: """Collect MTP loss from Megatron-Core's MTPLossLoggingHelper after forward-backward. - The MTP loss is computed and backpropagated by Megatron-Core's MTP module - during the forward-backward pass via MTPLossAutoScaler. This function only - collects the loss VALUE for logging and monitoring purposes. + The MTP loss is computed during the forward pass and added directly to the + RL loss in _compute_logprobs_and_loss (bypassing MTPLossAutoScaler, which + fails under Megatron DDP/TP). This function only collects the loss VALUE + for logging and monitoring purposes. IMPORTANT: All CP ranks must participate in the all-reduce to avoid deadlock. The gate condition uses is_pipeline_last_stage() instead of @@ -948,10 +949,10 @@ def _collect_mtp_loss(self) -> dict[str, float]: "[MTPGradDiag] All MTP params have zero gradient " "in BOTH .main_grad and .grad. The MTP backward " "path is completely broken. Check: " - "1) MTPLossAutoScaler.backward is being called, " - "2) mtp_loss requires_grad=True, " - "3) _mtp_hs requires_grad=True, " - "4) activation checkpointing compatibility." + "1) MTP loss was stored in _mtp_loss_for_backward, " + "2) MTP loss was added to RL loss in _compute_logprobs_and_loss, " + "3) mtp_loss requires_grad=True, " + "4) _mtp_hs requires_grad=True." ) self.logger.info( "[MTPGradDiag] Deep chain check: examining " @@ -1126,6 +1127,9 @@ def forward_step(batch_iter, model): ) if self.enable_mtp_training: + _engine_ref = self + self._mtp_loss_for_backward = [] + _unwrapped = model while hasattr(_unwrapped, "module"): _unwrapped = _unwrapped.module @@ -1223,7 +1227,6 @@ def _patched_postprocess( (MTPLossAutoScaler does not leak MTP grad to backbone). """ from megatron.core.transformer.multi_token_prediction import ( - MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor, ) @@ -1355,14 +1358,16 @@ def _patched_postprocess( / self_model.config.mtp_num_layers ) if self_model.config.calculate_per_token_loss: - hidden_states = MTPLossAutoScaler.apply( - hidden_states, - mtp_loss_scale * mtp_loss, - ) + _mtp_loss_to_store = mtp_loss_scale * mtp_loss else: - hidden_states = MTPLossAutoScaler.apply( - hidden_states, - mtp_loss_scale * mtp_loss / num_tokens, + _mtp_loss_to_store = mtp_loss_scale * mtp_loss / num_tokens + _engine_ref._mtp_loss_for_backward.append(_mtp_loss_to_store) + if self_model.training and _logger.isEnabledFor(10): + _logger.debug( + f"[MTPFix] Stored MTP loss for backward: " + f"sum={_mtp_loss_to_store.sum().item():.6f}, " + f"requires_grad={_mtp_loss_to_store.requires_grad}, " + f"accumulator_len={len(_engine_ref._mtp_loss_for_backward)}" ) _logger.info( @@ -1474,7 +1479,8 @@ def _patched_get_embeddings( packed_seq_params=packed_seq_params, ) _ids, _pos, _dec_input, _hs = result - _dec_input = _dec_input.detach() + + _dec_input = _dec_input.detach().requires_grad_(True) from megatron.core.utils import ( make_viewless_tensor, ) @@ -1484,16 +1490,47 @@ def _patched_get_embeddings( requires_grad=True, keep_graph=False, ) - if not hasattr(_patched_get_embeddings, "_diag_done"): - _patched_get_embeddings._diag_done = True - import logging as _log_m - _ge_lg = _log_m.getLogger("MegatronEngine") - _hs_gfn = type(_hs.grad_fn).__name__ if _hs.grad_fn else "None" + + import logging as _log_m + _ge_lg = _log_m.getLogger("MegatronEngine") + + if not hasattr(_patched_get_embeddings, "_call_count"): + _patched_get_embeddings._call_count = 0 + _patched_get_embeddings._call_count += 1 + _call_n = _patched_get_embeddings._call_count + + if _call_n <= 4 or _call_n % 500 == 0: + _di_gfn = ( + type(_dec_input.grad_fn).__name__ + if _dec_input.grad_fn else "None(leaf)") + _hs_gfn = ( + type(_hs.grad_fn).__name__ + if _hs.grad_fn else "None(leaf)") _ge_lg.info( - "[MTPEmbDiag] _patched_get_embeddings: " - "_dec_input.rg=%s, _hs.rg=%s, _hs.grad_fn=%s", + "[MTPEmbDiag] _patched_get_embeddings " + "(call #%d): " + "_dec_input=[rg=%s, shape=%s, grad_fn=%s], " + "_hs=[rg=%s, shape=%s, grad_fn=%s]", + _call_n, _dec_input.requires_grad, - _hs.requires_grad, _hs_gfn) + list(_dec_input.shape), + _di_gfn, + _hs.requires_grad, + list(_hs.shape), + _hs_gfn, + ) + + if not _dec_input.requires_grad: + _ge_lg.error( + "[MTPEmbDiag] CRITICAL: _dec_input.requires_grad " + "is False! MTP gradients will be zero. " + "call #%d", _call_n) + if not _hs.requires_grad: + _ge_lg.error( + "[MTPEmbDiag] CRITICAL: _hs.requires_grad " + "is False! MTP gradients will be zero. " + "call #%d", _call_n) + return _ids, _pos, _dec_input, _hs _layer._get_embeddings = _patched_get_embeddings @@ -3222,6 +3259,14 @@ def _compute_logprobs_and_loss( total_loss_weight: torch.Tensor, loss_multiplier: float = 1.0, ) -> torch.Tensor: + _mtp_loss_for_this_mb = None + if ( + self.enable_mtp_training + and hasattr(self, '_mtp_loss_for_backward') + and self._mtp_loss_for_backward + ): + _mtp_loss_for_this_mb = self._mtp_loss_for_backward.pop(0) + local_weight = loss_weight_fn(inputs) if local_weight == 0: return output.mean() * 0.0 @@ -3280,6 +3325,18 @@ def _compute_logprobs_and_loss( loss = loss_fn(values, inputs) loss_scale = local_weight / total_loss_weight * loss_multiplier + + if _mtp_loss_for_this_mb is not None: + _mtp_contribution = _mtp_loss_for_this_mb.sum() + loss = loss + _mtp_contribution + if random.random() < 0.01: + self.logger.info( + f"[MTPFix] Added MTP loss to RL loss: " + f"mtp_contribution={_mtp_contribution.item():.6f}, " + f"rl_loss_before={(loss - _mtp_contribution).item():.6f}, " + f"combined_loss={loss.item():.6f}, loss_scale={loss_scale:.6f}" + ) + return loss * loss_scale def _compute_forward_result( From 7e4118a7cb93a237abb746bd2189b074326cc5c3 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 24 Apr 2026 15:29:14 +0800 Subject: [PATCH 068/140] feat(mtp): add mtp lr --- areal/api/cli_args.py | 9 +++ areal/engine/megatron_engine.py | 80 +++++++++++++++------ examples/math/gsm8k_grpo_megatron_mimo.yaml | 1 + 3 files changed, 69 insertions(+), 21 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index aa7d7042f5..677789e12d 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -374,6 +374,15 @@ class OptimizerConfig: gradient_clipping: float = field( default=1.0, metadata={"help": "Gradient clipping threshold"} ) + mtp_lr_scale: float = field( + default=1.0, + metadata={ + "help": "Learning rate scale factor for MTP parameters relative to base lr. " + "Effective MTP lr = lr * mtp_lr_scale. " + "Set to >1.0 (e.g., 100.0) to give MTP head a higher learning rate. " + "Only effective when enable_mtp_training=True.", + }, + ) @dataclass diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 4049b44feb..3f1ce4a2ea 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1010,10 +1010,20 @@ def optimizer_step(self): update_successful, grad_norm, _ = self.optimizer.step() current_lr = self.optimizer.param_groups[0]["lr"] + # Log MTP lr if using separate param group + _mtp_lr = None + if self.enable_mtp_training and len(self.optimizer.param_groups) > 1: + for _pg in self.optimizer.param_groups: + if _pg.get('max_lr', None) != self.optimizer.param_groups[0].get('max_lr', None): + _mtp_lr = _pg['lr'] + break + + return dict( update_successful=float(update_successful), grad_norm=float(grad_norm) if grad_norm is not None else float("nan"), lr=current_lr, + mtp_lr=_mtp_lr if _mtp_lr is not None else current_lr, ) def lr_scheduler_step(self): @@ -1459,6 +1469,7 @@ def _mtp_backward_hook(grad, _lg=_logger): for _layer in _mtp_block.layers: _orig_get_emb = _layer._get_embeddings + _emb_call_count = [0] # Closure variable for call counting def _patched_get_embeddings( input_ids, position_ids, @@ -1481,23 +1492,10 @@ def _patched_get_embeddings( _ids, _pos, _dec_input, _hs = result _dec_input = _dec_input.detach().requires_grad_(True) - from megatron.core.utils import ( - make_viewless_tensor, - ) - - _hs = make_viewless_tensor( - inp=_hs.detach(), - requires_grad=True, - keep_graph=False, - ) + _hs = _hs.detach().requires_grad_(True) - import logging as _log_m - _ge_lg = _log_m.getLogger("MegatronEngine") - - if not hasattr(_patched_get_embeddings, "_call_count"): - _patched_get_embeddings._call_count = 0 - _patched_get_embeddings._call_count += 1 - _call_n = _patched_get_embeddings._call_count + _emb_call_count[0] += 1 + _call_n = _emb_call_count[0] if _call_n <= 4 or _call_n % 500 == 0: _di_gfn = ( @@ -1506,12 +1504,13 @@ def _patched_get_embeddings( _hs_gfn = ( type(_hs.grad_fn).__name__ if _hs.grad_fn else "None(leaf)") - _ge_lg.info( + _engine_ref.logger.info( "[MTPEmbDiag] _patched_get_embeddings " - "(call #%d): " + "(call #%d, step=%d): " "_dec_input=[rg=%s, shape=%s, grad_fn=%s], " "_hs=[rg=%s, shape=%s, grad_fn=%s]", _call_n, + getattr(_engine_ref, '_global_step', -1), _dec_input.requires_grad, list(_dec_input.shape), _di_gfn, @@ -1521,12 +1520,12 @@ def _patched_get_embeddings( ) if not _dec_input.requires_grad: - _ge_lg.error( + _engine_ref.logger.error( "[MTPEmbDiag] CRITICAL: _dec_input.requires_grad " "is False! MTP gradients will be zero. " "call #%d", _call_n) if not _hs.requires_grad: - _ge_lg.error( + _engine_ref.logger.error( "[MTPEmbDiag] CRITICAL: _hs.requires_grad " "is False! MTP gradients will be zero. " "call #%d", _call_n) @@ -2150,7 +2149,46 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: torch, self.mcore_config.exp_avg_sq_dtype ) - self.optimizer = get_megatron_optimizer(mcore_opt_config, self.model) + + # --- MTP independent learning rate --- + _mtp_lr_config_overrides = None + _mtp_lr_scale = getattr(self.optimizer_config, 'mtp_lr_scale', 1.0) + if self.enable_mtp_training and _mtp_lr_scale != 1.0: + try: + from megatron.core.optimizer.optimizer_config import ParamKey + except ImportError: + ParamKey = None + if ParamKey is not None: + _mtp_lr = self.optimizer_config.lr * _mtp_lr_scale + _mtp_min_lr = ( + self.optimizer_config.min_lr_ratio + * self.optimizer_config.lr + * _mtp_lr_scale + ) + # Match all MTP parameters by name glob pattern + _mtp_param_key = ParamKey(name=("*.mtp.*",)) + _mtp_lr_config_overrides = { + _mtp_param_key: { + "max_lr": _mtp_lr, + "min_lr": _mtp_min_lr, + } + } + self.logger.info( + "[MTPOptim] MTP parameters will use separate lr: " + "max_lr=%.2e (scale=%.1fx), min_lr=%.2e, base_lr=%.2e", + _mtp_lr, _mtp_lr_scale, _mtp_min_lr, + self.optimizer_config.lr, + ) + else: + self.logger.warning( + "[MTPOptim] ParamKey not available in this megatron-core " + "version. MTP parameters will use the global learning rate." + ) + + self.optimizer = get_megatron_optimizer( + mcore_opt_config, self.model, + config_overrides=_mtp_lr_config_overrides, + ) warmup_steps_proportion = self.optimizer_config.warmup_steps_proportion warmup_steps = int(warmup_steps_proportion * ft_spec.total_train_steps) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index eea77f6908..776625de48 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -60,6 +60,7 @@ actor: lr_scheduler_type: constant gradient_clipping: 1.0 warmup_steps_proportion: 0.001 + mtp_lr_scale: 100.0 eps_clip: 0.4 temperature: ${gconfig.temperature} reward_scaling: 10.0 From 227677126f6f46d4d6eb338bfacb0c8d164f16c2 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 24 Apr 2026 19:35:33 +0800 Subject: [PATCH 069/140] fix(engine): add mtp clip --- areal/engine/megatron_engine.py | 81 +++++++++++++++++++-- examples/math/gsm8k_grpo_megatron_mimo.yaml | 6 +- 2 files changed, 78 insertions(+), 9 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 3f1ce4a2ea..d47830c78a 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1139,6 +1139,11 @@ def forward_step(batch_iter, model): if self.enable_mtp_training: _engine_ref = self self._mtp_loss_for_backward = [] + # MTP loss EMA for adaptive clipping (prevents loss spikes) + if not hasattr(self, '_mtp_loss_ema'): + self._mtp_loss_ema = None # Will be initialized on first MTP loss + self._mtp_loss_clip_count = 0 + self._mtp_loss_total_count = 0 _unwrapped = model while hasattr(_unwrapped, "module"): @@ -1811,6 +1816,11 @@ def process_output( loss_multiplier=loss_multiplier, ) + # Track global training step for diagnostic logging + if not hasattr(self, '_global_step'): + self._global_step = 0 + self._global_step += 1 + self.forward_backward_batch(mb_list, process_output, forward_only=False) DeviceRuntimeInfo.get_current().log("train_batch after forward_backward") @@ -2569,6 +2579,20 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta) -> None: self.engine_lock.release() + # Log MTP weight norms for drift monitoring + if mtp_hf_tensors: + _norm_strs = [] + for _tn, _tv in mtp_hf_tensors[:5]: # Log first 5 + _norm_strs.append( + f"{_tn}:{_tv.float().norm().item():.4f}" + ) + self.logger.info( + "[MTPSyncDiag] MTP weight norms at sync " + "(version=%d, %d tensors): %s", + getattr(self, '_version', -1), + len(mtp_hf_tensors), + ", ".join(_norm_strs), + ) def _serialize_mtp_tensors_for_update( self, mtp_hf_tensors: list[tuple[str, torch.Tensor]], @@ -3365,14 +3389,59 @@ def _compute_logprobs_and_loss( loss_scale = local_weight / total_loss_weight * loss_multiplier if _mtp_loss_for_this_mb is not None: - _mtp_contribution = _mtp_loss_for_this_mb.sum() + _mtp_contribution_raw = _mtp_loss_for_this_mb.sum() + # --- MTP loss adaptive clipping (Fix: prevent loss spike feedback loop) --- + # When mtp_detach_heads=True, MTP trains independently of backbone. + # A sudden MTP loss spike (e.g., 5x normal) causes large gradient + # updates that destabilize the draft model, crashing accept rate, + # which in turn produces worse training data -> even higher loss. + # Clipping breaks this positive feedback loop. + _mtp_clip_threshold = 5.0 # Clip if loss > 5x EMA + _mtp_ema_decay = 0.95 + _mtp_contribution = _mtp_contribution_raw + _mtp_was_clipped = False + self._mtp_loss_total_count += 1 + if self._mtp_loss_ema is None: + # Initialize EMA with first observed value + self._mtp_loss_ema = _mtp_contribution_raw.detach().item() + else: + _raw_val = _mtp_contribution_raw.detach().item() + _ema_val = self._mtp_loss_ema + if _ema_val > 0 and _raw_val > _mtp_clip_threshold * _ema_val: + # Clip: scale down to threshold * EMA + _clip_ratio = (_mtp_clip_threshold * _ema_val) / _raw_val + _mtp_contribution = _mtp_contribution_raw * _clip_ratio + _mtp_was_clipped = True + self._mtp_loss_clip_count += 1 + self.logger.warning( + "[MTPLossClip] MTP loss clipped: raw=%.4f, ema=%.4f, " + "threshold=%.1fx, clip_ratio=%.4f, clipped=%.4f, " + "clip_count=%d/%d", + _raw_val, _ema_val, _mtp_clip_threshold, + _clip_ratio, _mtp_contribution.detach().item(), + self._mtp_loss_clip_count, self._mtp_loss_total_count, + ) + # Update EMA (use raw value for stable tracking, not clipped) + self._mtp_loss_ema = ( + _mtp_ema_decay * _ema_val + + (1 - _mtp_ema_decay) * _raw_val + ) loss = loss + _mtp_contribution - if random.random() < 0.01: + _n = self._mtp_loss_total_count + if _n <= 4 or _n % 100 == 0: self.logger.info( - f"[MTPFix] Added MTP loss to RL loss: " - f"mtp_contribution={_mtp_contribution.item():.6f}, " - f"rl_loss_before={(loss - _mtp_contribution).item():.6f}, " - f"combined_loss={loss.item():.6f}, loss_scale={loss_scale:.6f}" + "[MTPLossDiag] MTP loss added to RL loss (call #%d): " + "raw=%.6f, applied=%.6f, clipped=%s, " + "ema=%.6f, rl_before=%.6f, combined=%.6f, " + "loss_scale=%.6f", + _n, + _mtp_contribution_raw.detach().item(), + _mtp_contribution.detach().item(), + _mtp_was_clipped, + self._mtp_loss_ema if self._mtp_loss_ema else 0.0, + (loss - _mtp_contribution).detach().item(), + loss.detach().item(), + loss_scale, ) return loss * loss_scale diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 776625de48..49dbe9bb02 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -60,7 +60,7 @@ actor: lr_scheduler_type: constant gradient_clipping: 1.0 warmup_steps_proportion: 0.001 - mtp_lr_scale: 100.0 + mtp_lr_scale: 10.0 eps_clip: 0.4 temperature: ${gconfig.temperature} reward_scaling: 10.0 @@ -82,7 +82,7 @@ actor: # MTP Online Training - keeps EAGLE draft heads aligned with evolving policy enable_mtp_training: true mtp_num_layers: 1 - mtp_loss_scaling_factor: 0.2 + mtp_loss_scaling_factor: 0.1 scheduling_spec: - task_type: worker @@ -94,7 +94,7 @@ actor: megatron: mtp_num_layers: 1 - mtp_loss_scaling_factor: 0.2 + mtp_loss_scaling_factor: 0.1 ref: backend: ${actor.backend} From b4f55437d83dea910349057e0e9fcb85af427a52 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 24 Apr 2026 20:58:03 +0800 Subject: [PATCH 070/140] refactor(megatron_engine): mv --- areal/engine/megatron_engine.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index d47830c78a..24a77434e4 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2579,20 +2579,6 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta) -> None: self.engine_lock.release() - # Log MTP weight norms for drift monitoring - if mtp_hf_tensors: - _norm_strs = [] - for _tn, _tv in mtp_hf_tensors[:5]: # Log first 5 - _norm_strs.append( - f"{_tn}:{_tv.float().norm().item():.4f}" - ) - self.logger.info( - "[MTPSyncDiag] MTP weight norms at sync " - "(version=%d, %d tensors): %s", - getattr(self, '_version', -1), - len(mtp_hf_tensors), - ", ".join(_norm_strs), - ) def _serialize_mtp_tensors_for_update( self, mtp_hf_tensors: list[tuple[str, torch.Tensor]], @@ -2938,6 +2924,20 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"buffer_size={buffer_size}" ) + if mtp_hf_tensors: + _norm_strs = [] + for _tn, _tv in mtp_hf_tensors[:5]: + _norm_strs.append( + f"{_tn}:{_tv.float().norm().item():.4f}" + ) + self.logger.info( + "[MTPSyncDiag] MTP weight norms at sync " + "(version=%d, %d tensors): %s", + meta.version, + len(mtp_hf_tensors), + ", ".join(_norm_strs), + ) + # Record a CUDA event on the default stream BEFORE any NCCL # broadcasts begin. At this point, all MTP tensors from # _collect_param()'s synchronous dist.all_gather() are fully From 00e44979785dd79f87a8890168d7471fa9c619d1 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 24 Apr 2026 23:56:07 +0800 Subject: [PATCH 071/140] feat(megatron_engine): ad --- areal/engine/megatron_engine.py | 159 +++++++++++++++++++++++++------- 1 file changed, 125 insertions(+), 34 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 24a77434e4..3346b86e5d 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1144,6 +1144,19 @@ def forward_step(batch_iter, model): self._mtp_loss_ema = None # Will be initialized on first MTP loss self._mtp_loss_clip_count = 0 self._mtp_loss_total_count = 0 + # [v5-F6] Hint SpecDec v2 env toggle for throughput (idempotent, + # rank-0 only to avoid N-rank log spam). + import os as _os_v5 + try: + _rank_v5 = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + except Exception: + _rank_v5 = 0 + if _rank_v5 == 0 and _os_v5.environ.get("SGLANG_ENABLE_SPEC_V2", "") == "": + self.logger.info( + "[MTPEnvHint] SGLANG_ENABLE_SPEC_V2 not set; " + "consider exporting SGLANG_ENABLE_SPEC_V2=True to " + "enable overlap scheduler for speculative decoding." + ) _unwrapped = model while hasattr(_unwrapped, "module"): @@ -1304,13 +1317,16 @@ def _patched_postprocess( # mtp_loss_scaling_factor) and matches # Megatron-Core's native implementation. _mtp_hs = hidden_states_list[mtp_layer_number + 1] - if _mtp_diag_mb_counter[0] == 0: + # [v5-F1c] Gate MB#0 forward diag to first 3 steps + every 100. + _gs_fwd = getattr(_engine_ref, '_global_step', 0) + if (_mtp_diag_mb_counter[0] == 0 + and (_gs_fwd <= 3 or _gs_fwd % 100 == 0)): _mtp_hs_gfn = type(_mtp_hs.grad_fn).__name__ if _mtp_hs.grad_fn else "None" _logger.info( - "[MTPFwdDiag] MB#0 Layer#%d: " + "[MTPFwdDiag] MB#0 Layer#%d step=%d: " "_mtp_hs.rg=%s, shape=%s, grad_fn=%s, " "hs.rg=%s", - mtp_layer_number, _mtp_hs.requires_grad, + mtp_layer_number, _gs_fwd, _mtp_hs.requires_grad, list(_mtp_hs.shape), _mtp_hs_gfn, hidden_states.requires_grad) mtp_logits, _ = self_model.output_layer( @@ -1344,12 +1360,15 @@ def _patched_postprocess( mtp_labels, mtp_logits ) mtp_loss = loss_mask * mtp_loss - if _mtp_diag_mb_counter[0] == 0: + # [v5-F1c] Gate MB#0 mtp_loss diag to first 3 steps + every 100. + _gs_ml = getattr(_engine_ref, '_global_step', 0) + if (_mtp_diag_mb_counter[0] == 0 + and (_gs_ml <= 3 or _gs_ml % 100 == 0)): _ml_gfn = type(mtp_loss.grad_fn).__name__ if mtp_loss.grad_fn else "None" _logger.info( - "[MTPFwdDiag] MB#0 mtp_loss: " + "[MTPFwdDiag] MB#0 mtp_loss step=%d: " "rg=%s, grad_fn=%s, sum=%.6f, num_tokens=%s", - mtp_loss.requires_grad, _ml_gfn, + _gs_ml, mtp_loss.requires_grad, _ml_gfn, mtp_loss.sum().item(), num_tokens) elif self_model.training and _logger.isEnabledFor(10): _logger.debug( @@ -1377,32 +1396,49 @@ def _patched_postprocess( else: _mtp_loss_to_store = mtp_loss_scale * mtp_loss / num_tokens _engine_ref._mtp_loss_for_backward.append(_mtp_loss_to_store) + # [v5-F4] Cap FIFO to avoid unbounded growth on producer/consumer drift. + _fifo_len = len(_engine_ref._mtp_loss_for_backward) + if _fifo_len > 32: + _logger.warning( + "[MTPFifoOverflow] MTP loss FIFO length=%d >32, " + "dropping oldest entry (producer-consumer drift).", + _fifo_len, + ) + _engine_ref._mtp_loss_for_backward.pop(0) if self_model.training and _logger.isEnabledFor(10): _logger.debug( f"[MTPFix] Stored MTP loss for backward: " f"sum={_mtp_loss_to_store.sum().item():.6f}, " f"requires_grad={_mtp_loss_to_store.requires_grad}, " - f"accumulator_len={len(_engine_ref._mtp_loss_for_backward)}" + f"accumulator_len={_fifo_len}" ) - _logger.info( - "[MTPDetach] MTP loss computed via direct output_layer call") + # [v5-F1a] Gate per-step to first MB to avoid 1.4k lines/step spam. + if _mtp_diag_mb_counter[0] == 0: + _logger.info( + "[MTPDetach] MTP loss computed via direct output_layer call (first MB of step)") + # [v5-F1b] Gate backward hook registration to first 3 steps + # then every 100 steps; previously fired every step × every MB#0. + _gs_v5 = getattr(_engine_ref, '_global_step', 0) + _should_log_bwd = (_gs_v5 <= 3 or _gs_v5 % 100 == 0) if (_mtp_diag_mb_counter[0] == 0 - and hidden_states.requires_grad): - def _mtp_backward_hook(grad, _lg=_logger): + and hidden_states.requires_grad + and _should_log_bwd): + def _mtp_backward_hook(grad, _lg=_logger, _gs=_gs_v5): + # Inner hook fires once per backward; log only on gated steps. _lg.info( - "[MTPBwdDiag] AutoScaler backward FIRED: " + "[MTPBwdDiag] AutoScaler backward FIRED (step=%d): " "grad.shape=%s, grad.norm=%.8f, " "grad.abs_max=%.8f", - list(grad.shape), + _gs, list(grad.shape), grad.float().norm().item(), grad.float().abs().max().item()) hidden_states.register_hook(_mtp_backward_hook) _logger.info( "[MTPFwdDiag] MB#0 Registered backward hook on " - "hidden_states(post-AutoScaler): shape=%s, rg=%s", - list(hidden_states.shape), + "hidden_states(post-AutoScaler) step=%d: shape=%s, rg=%s", + _gs_v5, list(hidden_states.shape), hidden_states.requires_grad) _mtp_diag_mb_counter[0] += 1 @@ -1502,7 +1538,8 @@ def _patched_get_embeddings( _emb_call_count[0] += 1 _call_n = _emb_call_count[0] - if _call_n <= 4 or _call_n % 500 == 0: + # [v5-F1d] Relax throttle 500->2000 to cut MTPEmbDiag spam ~4x. + if _call_n <= 4 or _call_n % 2000 == 0: _di_gfn = ( type(_dec_input.grad_fn).__name__ if _dec_input.grad_fn else "None(leaf)") @@ -1577,11 +1614,14 @@ def _mtp_loss_fn( _orig=_orig_clm, _lg=self.logger, ): - if _mtp_diag_mb_counter[0] <= 2: + # [v5-F1e] Gate LossFn diag to MB#0 of first 3 steps + every 100. + _gs_lfn = getattr(_engine_ref, '_global_step', 0) + if (_mtp_diag_mb_counter[0] == 0 + and (_gs_lfn <= 3 or _gs_lfn % 100 == 0)): _lg.info( - "[MTPLossFnDiag] _mtp_loss_fn called: " + "[MTPLossFnDiag] _mtp_loss_fn called step=%d: " "_rem=%d, _logits.rg=%s, shape=%s", - _rem[0], _logits.requires_grad, + _gs_lfn, _rem[0], _logits.requires_grad, list(_logits.shape)) if _rem[0] > 0: _rem[0] -= 1 @@ -2371,7 +2411,8 @@ def _collect_param( _has_tmp = hasattr(param, "tensor_model_parallel") _is_tmp = getattr(param, "tensor_model_parallel", False) if _has_tmp else False _is_dup = name in self._duplicated_param_names if self._duplicated_param_names else False - self.logger.info( + # [v5-F1f] Downgrade per-param trace to DEBUG (was INFO, ~21k lines/run). + self.logger.debug( f"[DiagImpl] Rank {dist.get_rank()} all_gather_param START " f"name={name}, has_tmp={_has_tmp}, is_tmp={_is_tmp}, is_dup={_is_dup}, " f"param_shape={tuple(param.shape)}, param_dtype={param.dtype}" @@ -2383,7 +2424,8 @@ def _collect_param( quantization_config=self.quantization_config, duplicated_param_names=self._duplicated_param_names, ) - self.logger.info( + # [v5-F1f] Downgrade per-param trace to DEBUG. + self.logger.debug( f"[DiagImpl] Rank {dist.get_rank()} all_gather_param DONE " f"name={name}, result_type={type(param).__name__}" ) @@ -2409,12 +2451,14 @@ def _impl_update_weight_from_distributed( import time as _diag_time _t0 = _diag_time.time() - self.logger.info( + # [v5-F1f] Downgrade per-param trace to DEBUG. + self.logger.debug( f"[DiagImpl] Rank {dist.get_rank()} _collect_param START " f"name={name}" ) param, param_size = self._collect_param(name, param) - self.logger.info( + # [v5-F1f] Downgrade per-param trace to DEBUG. + self.logger.debug( f"[DiagImpl] Rank {dist.get_rank()} _collect_param DONE " f"name={name}, param_size={param_size / 1024 / 1024:.2f} MB, " f"took={_diag_time.time() - _t0:.3f}s" @@ -2925,18 +2969,65 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) if mtp_hf_tensors: - _norm_strs = [] - for _tn, _tv in mtp_hf_tensors[:5]: - _norm_strs.append( - f"{_tn}:{_tv.float().norm().item():.4f}" + # [v5-F3] Compute norms for ALL tensors (was: only first 5). + # [v5-F5] Track prev norm per-tensor to surface drift direction + # and detect stall (draft model not learning from RL data). + if not hasattr(self, "_mtp_sync_prev_norms"): + self._mtp_sync_prev_norms = {} + _all_norms = [] + _deltas = [] + _stall_tensors = [] + for _tn, _tv in mtp_hf_tensors: + _cur = _tv.float().norm().item() + _prev = self._mtp_sync_prev_norms.get(_tn) + if _prev is None: + _all_norms.append((_tn, _cur, None)) + else: + _d = _cur - _prev + _all_norms.append((_tn, _cur, _d)) + _deltas.append(abs(_d)) + # Stall: weight changed by <1e-5 absolute (essentially frozen). + if _cur > 0 and abs(_d) < 1e-5: + _stall_tensors.append(_tn) + self._mtp_sync_prev_norms[_tn] = _cur + # Compact per-tensor summary line (rank-0 only to avoid DP-spam). + try: + _rank_v5 = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + except Exception: + _rank_v5 = 0 + if _rank_v5 == 0: + _fmt_parts = [] + for _tn, _cur, _d in _all_norms: + if _d is None: + _fmt_parts.append(f"{_tn}:{_cur:.4f}") + else: + _fmt_parts.append(f"{_tn}:{_cur:.4f}(Δ{_d:+.3e})") + _drift_summary = "" + if _deltas: + _max_d = max(_deltas) + _sum_d = sum(_deltas) + _drift_summary = f" | max|Δ|={_max_d:.3e} sum|Δ|={_sum_d:.3e}" + self.logger.info( + "[MTPSyncDiag] MTP weight norms at sync " + "(version=%d, %d tensors): %s%s", + meta.version, + len(mtp_hf_tensors), + ", ".join(_fmt_parts), + _drift_summary, ) - self.logger.info( - "[MTPSyncDiag] MTP weight norms at sync " - "(version=%d, %d tensors): %s", - meta.version, - len(mtp_hf_tensors), - ", ".join(_norm_strs), - ) + # [v5-F5] Stall warning: if >50% of MTP tensors show sub-1e-5 drift, + # the draft model isn't learning — root cause of accept-rate collapse. + if _deltas and len(_stall_tensors) >= 0.5 * len(_deltas): + self.logger.warning( + "[MTPSyncHealth] MTP training STALL detected at version=%d: " + "%d/%d tensors drift<1e-5. " + "Likely causes: (1) mtp_lr_scale too small, " + "(2) mtp_loss_scaling_factor too small, " + "(3) MTP gradient is being zeroed by detach. " + "Accept-rate collapse will follow. Stalled tensors (head): %s", + meta.version, len(_stall_tensors), len(_deltas), + ", ".join(_stall_tensors[:3]), + ) # Record a CUDA event on the default stream BEFORE any NCCL # broadcasts begin. At this point, all MTP tensors from From 3117ccf6eb33363337977d7a3e8133c473b30590 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 27 Apr 2026 15:04:20 +0800 Subject: [PATCH 072/140] fix: h20 config --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 49dbe9bb02..8b777cbbbd 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -19,7 +19,7 @@ scheduler: type: null rollout: - backend: "sglang:d1p1t2" + backend: "sglang:d1p1t1" experiment_name: ${experiment_name} trial_name: ${trial_name} max_concurrent_rollouts: 256 @@ -40,7 +40,7 @@ gconfig: temperature: 1.0 actor: - backend: "megatron:d1p1t4" + backend: "megatron:d2p1t2" experiment_name: ${experiment_name} trial_name: ${trial_name} path: /workspace/models/MiMo-7B-RL @@ -51,7 +51,7 @@ actor: mb_spec: max_tokens_per_mb: 2048 optimizer: - type: adam_bf16 + type: adam lr: 3e-6 weight_decay: 0.003 beta1: 0.9 From 6dba807ad68fd14abb9f997b528f49cfa20e6c58 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 27 Apr 2026 15:14:42 +0800 Subject: [PATCH 073/140] perf: fix config --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 8b777cbbbd..b7624a2cc6 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -25,7 +25,7 @@ rollout: max_concurrent_rollouts: 256 queue_size: null consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 0 + max_head_offpolicyness: 1 enable_rollout_tracing: false scheduling_spec: ${actor.scheduling_spec} fileroot: ${cluster.fileroot} @@ -49,7 +49,7 @@ actor: gradient_checkpointing: true dtype: bfloat16 mb_spec: - max_tokens_per_mb: 2048 + max_tokens_per_mb: 10240 optimizer: type: adam lr: 3e-6 @@ -60,7 +60,6 @@ actor: lr_scheduler_type: constant gradient_clipping: 1.0 warmup_steps_proportion: 0.001 - mtp_lr_scale: 10.0 eps_clip: 0.4 temperature: ${gconfig.temperature} reward_scaling: 10.0 @@ -82,7 +81,7 @@ actor: # MTP Online Training - keeps EAGLE draft heads aligned with evolving policy enable_mtp_training: true mtp_num_layers: 1 - mtp_loss_scaling_factor: 0.1 + mtp_loss_scaling_factor: 0.2 scheduling_spec: - task_type: worker @@ -94,7 +93,7 @@ actor: megatron: mtp_num_layers: 1 - mtp_loss_scaling_factor: 0.1 + mtp_loss_scaling_factor: 0.2 ref: backend: ${actor.backend} @@ -105,7 +104,7 @@ ref: disable_dropout: true dtype: ${actor.dtype} mb_spec: - max_tokens_per_mb: 2048 + max_tokens_per_mb: 10240 optimizer: null scheduling_strategy: type: colocation @@ -120,7 +119,7 @@ sglang: dtype: ${actor.dtype} max_running_requests: null context_length: 32768 - mem_fraction_static: 0.4 + mem_fraction_static: 0.8 # EAGLE speculative decoding settings speculative_algorithm: "EAGLE" @@ -142,7 +141,7 @@ vllm: # datasets train_dataset: - batch_size: 128 + batch_size: 256 shuffle: true pin_memory: true num_workers: 4 @@ -151,7 +150,7 @@ train_dataset: max_length: 1024 valid_dataset: - batch_size: 128 + batch_size: 256 pin_memory: true num_workers: 4 path: openai/gsm8k From a9161e7b941c4371f16cb529133cc8dcab458fd0 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 27 Apr 2026 16:07:34 +0800 Subject: [PATCH 074/140] feat: add log --- areal/engine/megatron_engine.py | 85 +++++++++-- areal/infra/controller/train_controller.py | 84 ++++++----- areal/infra/rpc/guard/engine_blueprint.py | 164 +++++++++++++-------- areal/infra/scheduler/local.py | 117 +++++++++------ areal/trainer/rl_trainer.py | 77 +++++++--- 5 files changed, 353 insertions(+), 174 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 3346b86e5d..d6edcfd5ea 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -211,20 +211,41 @@ def __init__(self, config: TrainEngineConfig): ) def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): + import time as _time + + _t0 = _time.time() if parallel_strategy is None: parallel_strategy = ParallelStrategy() self.parallel_strategy = self._make_parallel_strategy(parallel_strategy) backend = current_platform.communication_backend + if not dist.is_initialized(): - # NOTE: device_id **SHOULD NOT** be passed into init_process_group, - # otherwise initializing the NCCL weight update group will be wrong! + self.logger.info( + "[DiagInit] create_process_group: calling dist.init_process_group " + f"(backend={backend}, RANK={os.environ.get('RANK')}, " + f"WORLD_SIZE={os.environ.get('WORLD_SIZE')})..." + ) + _t1 = _time.time() dist.init_process_group( backend=backend, timeout=DIST_GROUP_DEFAULT_TIMEOUT, ) - # Initialize Megatron parallel states - # NOTE: we assume all MegatronEngine has the same parallel strategy. + self.logger.info( + f"[DiagInit] create_process_group: dist.init_process_group done in " + f"{_time.time() - _t1:.2f}s" + ) + vpp_size = self.parallel_strategy.virtual_pipeline_parallel_size + self.logger.info( + f"[DiagInit] create_process_group: calling mpu.initialize_model_parallel " + f"(tp={self.parallel_strategy.tensor_parallel_size}, " + f"pp={self.parallel_strategy.pipeline_parallel_size}, " + f"cp={self.parallel_strategy.context_parallel_size}, " + f"ep={self.parallel_strategy.expert_parallel_size}, " + f"etp={self.parallel_strategy.expert_tensor_parallel_size}, " + f"vpp={vpp_size})..." + ) + _t2 = _time.time() mpu.initialize_model_parallel( tensor_model_parallel_size=self.parallel_strategy.tensor_parallel_size, pipeline_model_parallel_size=self.parallel_strategy.pipeline_parallel_size, @@ -238,17 +259,28 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None DIST_GROUP_DEFAULT_TIMEOUT.seconds / 60 ), ) - # Set megatron model parallel seed + self.logger.info( + f"[DiagInit] create_process_group: mpu.initialize_model_parallel done in " + f"{_time.time() - _t2:.2f}s" + ) + tensor_parallel.model_parallel_cuda_manual_seed(self.seed) self.own_global_group = True + else: + self.logger.info( + "[DiagInit] create_process_group: dist already initialized, skipping init_process_group" + ) + self.logger = logging.getLogger(f"[MegatronEngine Rank {dist.get_rank()}]") self._context_and_model_parallel_group = None self._init_context_and_model_parallel_group() - # This is needed for barrier synchronization when models are moved to CPU self._cpu_group = dist.new_group( timeout=DIST_GROUP_DEFAULT_TIMEOUT, backend="gloo" ) self.process_group_initialized = True + self.logger.info( + f"[DiagInit] create_process_group: COMPLETED in {_time.time() - _t0:.2f}s" + ) def _apply_megatron_bridge_lora(self) -> None: assert self.model is not None, "Model must be initialized before applying LoRA." @@ -287,6 +319,11 @@ def _apply_megatron_bridge_lora(self) -> None: ) def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): + import time as _time + + _t0 = _time.time() + self.logger.info("[DiagInit] initialize: ENTERED") + try: self.seed = get_seed() except ValueError: @@ -313,6 +350,10 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): f"update_weight_group_{mpu.get_pipeline_model_parallel_rank()}" ) self.engine_lock = DistributedLock("train_engine_lock") + self.logger.info( + f"[DiagInit] initialize: rank={self.rank}, world_size={self.world_size}, " + f"device={self.device}, is_pp_head={self.is_pp_head}" + ) if self.config.use_lora and self.bridge_cls != "megatron-bridge": raise NotImplementedError( @@ -320,13 +361,21 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): "mbridge does not support LoRA in this path." ) + self.logger.info("[DiagInit] initialize: loading tokenizer...") + _t1 = _time.time() self.tokenizer = load_hf_tokenizer(self.config.path) + self.logger.info(f"[DiagInit] initialize: tokenizer loaded in {_time.time() - _t1:.2f}s") + self.logger.info("[DiagInit] initialize: building HF/Megatron bridge...") + _t2 = _time.time() with patch_bridge_for_tree_training( self.enable_tree_training and self.bridge_cls == "mbridge" ): self.bridge = self._build_hf_mcore_bridge() + self.logger.info(f"[DiagInit] initialize: bridge built in {_time.time() - _t2:.2f}s") + self.logger.info("[DiagInit] initialize: making HF and mcore config...") + _t3 = _time.time() self.hf_config, self.tf_config = make_hf_and_mcore_config( self.config.path, dtype=self.dtype, @@ -336,6 +385,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self.tf_config = configure_pipeline_layer_splits( self.parallel_strategy, self.hf_config, self.tf_config ) + self.logger.info(f"[DiagInit] initialize: configs made in {_time.time() - _t3:.2f}s") self.quantization_config = getattr( self.hf_config, "quantization_config", None @@ -344,7 +394,6 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self._check_and_apply_fp8_config() self._validate_fp8_consistency() - # Propagate MTP config to tf_config (TransformerConfig) for model creation if self.enable_mtp_training: self.tf_config.mtp_num_layers = self.mtp_num_layers self.tf_config.mtp_loss_scaling_factor = self.mtp_loss_scaling_factor @@ -357,13 +406,6 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): f"mtp_detach_heads={self.mtp_detach_heads}" ) else: - # When MTP training is disabled, clear mtp_num_layers to - # prevent mbridge from creating MTP layers. Without this, - # models like MiMo whose HF config contains - # num_nextn_predict_layers>0 would still create MTP layers - # through mbridge, causing _postprocess() to enter the MTP - # loss path and crash on labels.clone() when labels is None - # during inference. _orig_mtp = getattr(self.tf_config, "mtp_num_layers", None) if _orig_mtp is not None and _orig_mtp > 0: self.tf_config.mtp_num_layers = None @@ -373,6 +415,8 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): f"MTP layers will NOT be created in GPTModel." ) + self.logger.info("[DiagInit] initialize: creating Megatron model...") + _t4 = _time.time() with self.device: models = make_mcore_model( hf_config=self.hf_config, @@ -384,14 +428,21 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): use_lora=self.config.use_lora, enable_mtp=self.enable_mtp_training, ) + self.logger.info(f"[DiagInit] initialize: Megatron model created in {_time.time() - _t4:.2f}s") self.model = _MegatronModelList(models) if self.config.use_lora: + self.logger.info("[DiagInit] initialize: applying Megatron Bridge LoRA...") + _t_lora = _time.time() self._apply_megatron_bridge_lora() + self.logger.info(f"[DiagInit] initialize: LoRA applied in {_time.time() - _t_lora:.2f}s") + self.logger.info("[DiagInit] initialize: loading model weights from HF...") + _t5 = _time.time() with self.device: self._load_model_from_hf(self.config.path) + self.logger.info(f"[DiagInit] initialize: HF weights loaded in {_time.time() - _t5:.2f}s") # NOTE: Clear high_precision_init_val for FP8 parameters. # @@ -464,7 +515,10 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): if len(self.model) == 1: model_config.param_sync_func = model_config.param_sync_func[0] model_config.finalize_model_grads_func = finalize_model_grads + self.logger.info("[DiagInit] initialize: creating optimizer...") + _t6 = _time.time() self._create_optimizer(ft_spec) + self.logger.info(f"[DiagInit] initialize: optimizer created in {_time.time() - _t6:.2f}s") if self.enable_mtp_training and not self._mtp_layers_verified: mtp_param_count = 0 @@ -514,6 +568,9 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): ) self._initialized = True + self.logger.info( + f"[DiagInit] initialize: COMPLETED in {_time.time() - _t0:.2f}s total" + ) def _build_hf_mcore_bridge(self): if self.bridge_cls == "mbridge": diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index 59bef0a85b..91e2cdad9b 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -255,25 +255,12 @@ def initialize( ft_spec: FinetuneSpec, **kwargs, ): - """Initialize environments for distributed training and load models. + import time as _time - Parameters - ---------- - role : str - Role identifier for the workers - ft_spec : FinetuneSpec - Finetune specification for model initialization - **kwargs - Additional keyword arguments passed to engine initialization - """ - # Store configuration self._worker_role = role world_size = self.train_alloc.parallel.world_size - # Create job specification for scheduler - # Convert scheduling_spec tuple to list for scheduler compatibility - # The scheduler will handle task replication across workers if needed job = Job( replicas=world_size, tasks=list(self.config.scheduling_spec), @@ -281,19 +268,16 @@ def initialize( role=self._worker_role, ) - # Create workers via scheduler logger.info("Creating workers via scheduler...") + _t0 = _time.time() worker_ids = self.scheduler.create_workers(job=job) - logger.info(f"Workers created: {worker_ids}") + logger.info(f"Workers created: {worker_ids} in {_time.time() - _t0:.2f}s") - # Wait for workers to be ready logger.info("Waiting for workers to be ready...") + _t1 = _time.time() self.workers = self.scheduler.get_workers(role=job.role) - logger.info(f"Workers ready: {[w.id for w in self.workers]}") + logger.info(f"Workers ready: {[w.id for w in self.workers]} in {_time.time() - _t1:.2f}s") - # Determine distributed training master address and port from rank 0 worker - # These are used for PyTorch distributed initialization across workers - # Prefer engine_ports[1] if available, fallback to worker_ports[1] rank0_worker = self.workers[0] if rank0_worker.engine_ports: self._master_port = int(rank0_worker.engine_ports[1]) @@ -305,20 +289,26 @@ def initialize( f"Distributed training: MASTER_ADDR={self._master_addr}, MASTER_PORT={self._master_port}" ) - # Construct engine class import path for dynamic loading on workers - # Workers will import and instantiate the engine class using this path engine_class = self.train_engine + engine_path = f"{engine_class.__module__}.{engine_class.__name__}" - # Create and initialize engines on workers + logger.info(f"Creating engines on workers (class={engine_path})...") + _t2 = _time.time() run_async_task( self._async_create_engines, - f"{engine_class.__module__}.{engine_class.__name__}", + engine_path, ) + logger.info(f"Engines created on all workers in {_time.time() - _t2:.2f}s") + + logger.info("Initializing engines on workers...") + _t3 = _time.time() run_async_task(self._async_initialize_engines, ft_spec, **kwargs) + logger.info(f"All engines initialized in {_time.time() - _t3:.2f}s") - # Identify DP head workers self._identify_dp_heads() - logger.info("TrainController initialization complete") + logger.info( + f"TrainController initialization complete (total: {_time.time() - _t0:.2f}s)" + ) def _engine_name(self, rank: int) -> str: """Generate engine name for a worker rank. @@ -328,35 +318,53 @@ def _engine_name(self, rank: int) -> str: return f"{self._worker_role}/{rank}" async def _async_create_engines(self, engine: str): - """Create engine instances on all workers. Sets distributed env vars before creation.""" + import time as _time + logger.info("Creating engines on workers...") async def _setup_worker(worker: Worker, rank: int): + _wt0 = _time.time() env = { "RANK": str(rank), "WORLD_SIZE": str(len(self.workers)), "MASTER_ADDR": str(self._master_addr), "MASTER_PORT": str(self._master_port), - "LOCAL_RANK": "0", # NOTE: local rank is always 0 while each process use only one GPU + "LOCAL_RANK": "0", } + logger.info( + f"[DiagInit] _setup_worker {worker.id}: sending /set_env " + f"(RANK={rank}, WORLD_SIZE={len(self.workers)})..." + ) await self.scheduler.set_worker_env(worker.id, env) + logger.info( + f"[DiagInit] _setup_worker {worker.id}: /set_env done in " + f"{_time.time() - _wt0:.2f}s, creating engine..." + ) await self.scheduler.create_engine( worker_id=worker.id, engine=engine, engine_name=self._engine_name(rank), config=self.config, ) + logger.info( + f"[DiagInit] _setup_worker {worker.id}: engine created in " + f"{_time.time() - _wt0:.2f}s total" + ) + _t0 = _time.time() tasks = [ _setup_worker(worker, rank) for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) - logger.info("Engines created on all workers!") + logger.info(f"Engines created on all workers in {_time.time() - _t0:.2f}s!") async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): - """Initialize engines: create process groups, then load models and setup optimizers.""" + import time as _time + logger.info("Calling engine initialization...") - # Phase 1: Create process groups for distributed training + _t0 = _time.time() + + logger.info("[DiagInit] Phase 1: create_process_group on all workers...") tasks = [ self.scheduler.async_call_engine( worker_id=worker.id, @@ -367,7 +375,13 @@ async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) - # Phase 2: Initialize engines (load models, setup optimizers, etc.) + logger.info( + f"[DiagInit] Phase 1 done: create_process_group completed in " + f"{_time.time() - _t0:.2f}s" + ) + + logger.info("[DiagInit] Phase 2: initialize (load models, setup optimizers) on all workers...") + _t1 = _time.time() tasks = [ self.scheduler.async_call_engine( worker_id=worker.id, @@ -379,6 +393,10 @@ async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) + logger.info( + f"[DiagInit] Phase 2 done: initialize completed in " + f"{_time.time() - _t1:.2f}s (total: {_time.time() - _t0:.2f}s)" + ) logger.info("All engines are initialized!") def _identify_dp_heads(self): diff --git a/areal/infra/rpc/guard/engine_blueprint.py b/areal/infra/rpc/guard/engine_blueprint.py index 4dbd1c765d..c540458e60 100644 --- a/areal/infra/rpc/guard/engine_blueprint.py +++ b/areal/infra/rpc/guard/engine_blueprint.py @@ -105,14 +105,28 @@ def engine_worker(): def _submit_to_engine_thread( func_name: str, func: Callable, *args: Any, **kwargs: Any ) -> Any: - """Submit work to the engine thread and block until result is available.""" global _engine_work_queue _init_engine_thread() + _is_init_op = func_name in ( + "configure", + "create_engine", + "call_create_process_group", + "call_initialize", + "call_connect_engine", + ) + if _is_init_op: + logger.info(f"[DiagInit] _submit_to_engine_thread: submitting '{func_name}'...") + future: Future = Future() _engine_work_queue.put((func, args, kwargs, future, func_name)) - return future.result() # Block until result is available + result = future.result() + + if _is_init_op: + logger.info(f"[DiagInit] _submit_to_engine_thread: '{func_name}' completed") + + return result # --------------------------------------------------------------------------- @@ -142,13 +156,9 @@ def _engine_health_hook() -> dict[str, Any]: def _engine_configure_hook(data: dict) -> dict: - """Handle /configure by setting random seeds in the engine thread. + import time as _time - Raises - ------ - ValueError - If required fields (``config``, ``rank``) are missing. - """ + _t0 = _time.time() config_data = data.get("config") if config_data is None: raise ValueError("Missing 'config' field in request") @@ -157,21 +167,36 @@ def _engine_configure_hook(data: dict) -> dict: if rank is None: raise ValueError("Missing 'rank' field in request") + logger.info(f"[DiagInit] _engine_configure_hook ENTERED (rank={rank})") + + logger.info(f"[DiagInit] _engine_configure_hook rank={rank}: deserializing config...") + _t1 = _time.time() config = deserialize_value(config_data) + logger.info( + f"[DiagInit] _engine_configure_hook rank={rank}: config deserialized in " + f"{_time.time() - _t1:.2f}s" + ) - # Capture role from GuardState (we're in a request context) state = get_state() role = state.role def execute_configure(): + logger.info(f"[DiagInit] execute_configure rank={rank}: setting random seed...") seeding.set_random_seed(config.seed, key=f"{role}{rank}") + logger.info(f"[DiagInit] execute_configure rank={rank}: seed set successfully") return { "status": "success", "message": "Worker configured successful.", "result": None, } - return _submit_to_engine_thread("configure", execute_configure) + logger.info(f"[DiagInit] _engine_configure_hook rank={rank}: submitting to engine thread...") + result = _submit_to_engine_thread("configure", execute_configure) + logger.info( + f"[DiagInit] _engine_configure_hook rank={rank}: COMPLETED in " + f"{_time.time() - _t0:.2f}s" + ) + return result # --------------------------------------------------------------------------- @@ -262,31 +287,19 @@ def execute_set_env(): @engine_bp.route("/create_engine", methods=["POST"]) def create_engine(): - """Create and initialize an engine instance on this worker. + global _engines - This endpoint is routed to the engine thread for serial execution. - Supports multiple engines per worker, keyed by ``engine_name``. + import time as _time - Expected JSON payload:: - - { - "engine": "areal.engine.fsdp_engine.FSDPPPOActor", - "engine_name": "actor/0", - "init_args": [...], - "init_kwargs": {"config": ...} - } - """ - global _engines + _t0 = _time.time() try: - # Parse request in main thread (has Flask request context) data = request.get_json() if data is None: return jsonify({"error": "Invalid JSON in request body"}), 400 engine = data.get("engine") engine_name = data.get("engine_name") - # Deserialize init_args and init_kwargs (may contain tensors/dataclasses) init_args = deserialize_value(data.get("init_args", [])) init_kwargs = deserialize_value(data.get("init_kwargs", {})) @@ -314,11 +327,16 @@ def create_engine(): 400, ) - # Dynamic import (can be done in main thread) + logger.info( + f"[DiagInit] /create_engine ENTERED: engine={engine}, " + f"engine_name={engine_name}" + ) + try: + logger.info(f"[DiagInit] /create_engine {engine_name}: importing engine class...") + _t1 = _time.time() engine_class = import_from_string(engine) - # Validate that the class is a TrainEngine or InferenceEngine if not issubclass(engine_class, TrainEngine) and not issubclass( engine_class, InferenceEngine ): @@ -326,6 +344,10 @@ def create_engine(): "Engine class must be a subclass of TrainEngine or " f"InferenceEngine, got {engine_class}.." ) + logger.info( + f"[DiagInit] /create_engine {engine_name}: engine class imported in " + f"{_time.time() - _t1:.2f}s" + ) except (ValueError, ImportError, AttributeError) as e: logger.error(f"Failed to import engine '{engine}': {e}") return ( @@ -336,14 +358,17 @@ def create_engine(): logger.error(f"Invalid engine type: {e}") return jsonify({"error": str(e)}), 400 - # Instantiate engine in engine thread (may involve NCCL init) def create_engine_in_engine_thread(): - """Create engine in engine thread.""" try: + logger.info( + f"[DiagInit] /create_engine {engine_name}: instantiating " + f"{engine} in engine thread..." + ) + _t2 = _time.time() engine_obj = engine_class(*init_args, **init_kwargs) logger.info( - f"Engine '{engine_name}' (class: {engine}) " - "instantiated successfully" + f"[DiagInit] /create_engine {engine_name}: instantiated in " + f"{_time.time() - _t2:.2f}s" ) return engine_obj except Exception as e: @@ -353,10 +378,17 @@ def create_engine_in_engine_thread(): raise try: + logger.info( + f"[DiagInit] /create_engine {engine_name}: submitting to engine thread..." + ) engine_obj = _submit_to_engine_thread( "create_engine", create_engine_in_engine_thread ) _engines[engine_name] = engine_obj + logger.info( + f"[DiagInit] /create_engine {engine_name}: COMPLETED in " + f"{_time.time() - _t0:.2f}s total" + ) return jsonify( { "status": "success", @@ -380,21 +412,11 @@ def create_engine_in_engine_thread(): @engine_bp.route("/call", methods=["POST"]) def call_engine_method(): - """Call a method on an engine instance. - - This endpoint is routed to the engine thread to ensure all engine - operations run serially in the same thread, preventing NCCL conflicts. + global _engines - Expected JSON payload:: + import time as _time - { - "method": "train_batch", - "engine_name": "actor/0", - "args": [...], - "kwargs": {...} - } - """ - global _engines + _t0 = _time.time() try: data = request.get_json() @@ -429,19 +451,26 @@ def call_engine_method(): 404, ) - # Get the specific engine to call engine = _engines[engine_name] - # Deserialize data raw_args = deserialize_value(raw_args) raw_kwargs = deserialize_value(raw_kwargs) - # Fetch remote tensors args = RTensor.localize(raw_args) kwargs = RTensor.localize(raw_kwargs) + _is_init_method = method_name in ( + "create_process_group", + "initialize", + "connect_engine", + ) + if _is_init_method: + logger.info( + f"[DiagInit] /call ENTERED: method={method_name}, " + f"engine={engine_name}" + ) + def execute_in_engine_thread(): try: - # Broadcast args when engine is a TrainEngine and initialized if isinstance(engine, TrainEngine) and engine.initialized: logger.debug( f"Broadcasting data for TrainEngine method: {method_name}" @@ -463,19 +492,17 @@ def execute_in_engine_thread(): group=engine.context_and_model_parallel_group, ) - args_bcast = tensor_container_to( - args, current_platform.current_device() - ) args_bcast = broadcast_tensor_container( - args_bcast, + tensor_container_to( + args, current_platform.current_device() + ), src_rank=engine.current_data_parallel_head(), group=engine.context_and_model_parallel_group, ) - kwargs_bcast = tensor_container_to( - kwargs, current_platform.current_device() - ) kwargs_bcast = broadcast_tensor_container( - kwargs_bcast, + tensor_container_to( + kwargs, current_platform.current_device() + ), src_rank=engine.current_data_parallel_head(), group=engine.context_and_model_parallel_group, ) @@ -484,10 +511,16 @@ def execute_in_engine_thread(): args_bcast = args kwargs_bcast = kwargs + if _is_init_method: + logger.info( + f"[DiagInit] /call {engine_name}.{method_name}: " + f"executing in engine thread..." + ) + _et0 = _time.time() + logger.debug(f"Calling engine '{engine_name}' method: {method_name}") - # Determine trace category based on method name - category = "misc" # Default category + category = "misc" method_lower = method_name.lower() if any(keyword in method_lower for keyword in ["submit", "wait"]): category = "scheduler" @@ -514,7 +547,6 @@ def execute_in_engine_thread(): ): category = "compute" - # Wrap engine method call with perf_tracer with perf_tracer.trace_scope( f"rpc.{method_name}", category=category, @@ -523,12 +555,17 @@ def execute_in_engine_thread(): method = getattr(engine, method_name) result = method(*args_bcast, **kwargs_bcast) - # Handle update weights future if isinstance(result, Future): logger.debug("Waiting for update weights future") result = result.result() logger.debug("Update weights future done") + if _is_init_method: + logger.info( + f"[DiagInit] /call {engine_name}.{method_name}: " + f"COMPLETED in {_time.time() - _et0:.2f}s" + ) + return result except AttributeError as e: logger.error(f"Method '{method_name}' not found on engine: {e}") @@ -558,7 +595,12 @@ def execute_in_engine_thread(): 500, ) - # Convert all tensors to RTensors and store locally + if _is_init_method: + logger.info( + f"[DiagInit] /call {engine_name}.{method_name}: total " + f"{_time.time() - _t0:.2f}s (including RPC overhead)" + ) + state = get_state() result = RTensor.remotize(result, node_addr=state.node_addr) serialized_result = serialize_value(result) diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index 8c1b9a7a35..ba4b1d36e1 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -834,41 +834,16 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: return worker_ids def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: - """Get workers and wait for them to be ready. + logger.info(f"[DiagInit] get_workers ENTERED for role='{role}'") - Parameters - ---------- - role : str - Worker role name - timeout : float, optional - Maximum time to wait for workers to be ready (None = use default) - - Returns - ------- - list[Worker] - List of Worker objects - - Raises - ------ - WorkerNotFoundError - If role doesn't exist - WorkerFailedError - If any worker process failed - WorkerTimeoutError - If timeout exceeded waiting for workers - """ - # Handle colocated/forked roles if role in self._colocated_roles: - # Forked roles have their own workers in _workers if role not in self._workers: - # Colocated roles delegate to target role's workers target_role = self._colocated_roles[role] logger.debug( f"Role '{role}' is colocated with '{target_role}', " "returning target role's workers" ) return self.get_workers(target_role, timeout) - # Forked roles fall through to normal worker handling below if role not in self._workers: raise WorkerNotFoundError(role) @@ -876,6 +851,10 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: workers = self._workers[role] timeout = timeout if timeout is not None else self.startup_timeout + logger.info( + f"[DiagInit] get_workers role='{role}': checking health of " + f"{len(workers)} workers (timeout={timeout}s)..." + ) self._check_worker_health(role) start_time = time.time() @@ -883,6 +862,11 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: while len(ready_workers) < len(workers): if time.time() - start_time > timeout: + not_ready = [w.worker.id for w in workers if w.worker.id not in ready_workers] + logger.error( + f"[DiagInit] get_workers role='{role}': TIMEOUT after {timeout}s. " + f"Ready: {ready_workers}, NOT ready: {not_ready}" + ) raise WorkerTimeoutError( role, timeout, @@ -892,7 +876,6 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: if worker_info.worker.id in ready_workers: continue - # Forked workers have process=None - skip process check for them if ( worker_info.process is not None and worker_info.process.poll() is not None @@ -911,7 +894,10 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: if len(ready_workers) < len(workers): time.sleep(self.health_check_interval) - logger.info(f"All {len(workers)} workers for role '{role}' are ready") + logger.info( + f"[DiagInit] get_workers role='{role}': all {len(workers)} workers ready " + f"in {time.time() - start_time:.2f}s" + ) return [w.worker for w in workers] def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: @@ -925,31 +911,55 @@ def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: return False def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): + worker_id = worker_info.worker.id + logger.info(f"[DiagInit] _configure_worker ENTERED for {worker_id} (rank={worker_rank})") + + logger.info(f"[DiagInit] _configure_worker {worker_id}: waiting for /health endpoint...") + _t0 = time.time() while not self._is_worker_ready(worker_info): time.sleep(0.1) + logger.info( + f"[DiagInit] _configure_worker {worker_id}: /health ready in {time.time() - _t0:.2f}s" + ) - worker_id = worker_info.worker.id port = int(worker_info.worker.worker_ports[0]) url = f"http://{format_hostport(worker_info.worker.ip, port)}/configure" try: + logger.info(f"[DiagInit] _configure_worker {worker_id}: serializing config payload...") + _t1 = time.time() + payload_data = orjson.dumps( + serialize_value( + dict( + config=self.exp_config, + role=worker_info.role, + rank=worker_rank, + ) + ) + ) + logger.info( + f"[DiagInit] _configure_worker {worker_id}: config serialized in " + f"{time.time() - _t1:.2f}s, payload_size={len(payload_data)} bytes" + ) + + logger.info( + f"[DiagInit] _configure_worker {worker_id}: sending POST /configure " + f"to {url} (timeout=300s)..." + ) + _t2 = time.time() response = requests.post( url, - data=orjson.dumps( - serialize_value( - dict( - config=self.exp_config, - role=worker_info.role, - rank=worker_rank, - ) - ) - ), + data=payload_data, headers={"Content-Type": "application/json"}, timeout=300.0, ) + logger.info( + f"[DiagInit] _configure_worker {worker_id}: POST /configure responded " + f"in {time.time() - _t2:.2f}s with status={response.status_code}" + ) if response.status_code == 200: - logger.info(f"Configuration successfully on worker '{worker_id}'") + logger.info(f"[DiagInit] _configure_worker {worker_id}: Configuration successful") return elif response.status_code == 400: error_detail = response.json().get("error", "Unknown error") @@ -1180,9 +1190,11 @@ async def create_engine( url = f"http://{format_hostport(worker_info.worker.ip, port)}/create_engine" try: - logger.debug( - f"Creating engine '{engine_name}' (class: {engine}) on worker '{worker_id}'" + logger.info( + f"[DiagInit] create_engine: sending POST /create_engine to " + f"worker '{worker_id}' (engine={engine}, engine_name={engine_name})..." ) + _t0 = time.time() timeout = aiohttp.ClientTimeout(total=300.0) async with aiohttp.ClientSession( @@ -1197,8 +1209,9 @@ async def create_engine( ) as response: if response.status == 200: result = await response.json() - logger.debug( - f"Engine '{engine_name}' created successfully on worker '{worker_id}'" + logger.info( + f"[DiagInit] create_engine: engine '{engine_name}' " + f"created on worker '{worker_id}' in {time.time() - _t0:.2f}s" ) return result.get("result") elif response.status == 400: @@ -1455,6 +1468,18 @@ async def async_call_engine( ) try: + _is_init_method = method in ( + "create_process_group", + "initialize", + "connect_engine", + ) + if _is_init_method: + logger.info( + f"[DiagInit] async_call_engine: sending POST /call " + f"(method={method}, worker={worker_id}, engine={engine_name})..." + ) + _t0 = time.time() + logger.debug( f"Async calling method '{method}' on worker '{worker_id}' (attempt {attempt})" ) @@ -1476,6 +1501,12 @@ async def async_call_engine( if response.status == 200: result_data = (await response.json()).get("result") deserialized_result = deserialize_value(result_data) + if _is_init_method: + logger.info( + f"[DiagInit] async_call_engine: POST /call " + f"(method={method}, worker={worker_id}) " + f"completed in {time.time() - _t0:.2f}s" + ) if attempt > 1: logger.debug( f"Method '{method}' succeeded on worker '{worker_id}' " diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index d8e93cbc5e..68729ae6c1 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -100,54 +100,63 @@ def __init__( train_dataset: Dataset | None = None, valid_dataset: Dataset | None = None, ): + import time as _time + + _t0 = _time.time() rank = int(os.getenv("RANK", "0")) if is_single_controller(): - # Set up file logging for controller process logging.setup_file_logging(StatsLogger.get_log_path(config.stats_logger)) self.config = config + logger.info("[DiagInit] PPOTrainer.__init__: loading tokenizer...") + _t1 = _time.time() self.processor, self.tokenizer = load_hf_processor_and_tokenizer( config.tokenizer_path ) + logger.info(f"[DiagInit] PPOTrainer.__init__: tokenizer loaded in {_time.time() - _t1:.2f}s") + self.scheduler = None if is_single_controller(): + logger.info("[DiagInit] PPOTrainer.__init__: initializing scheduler...") self.scheduler = self._init_scheduler() - # Set seed. seeding.set_random_seed(config.seed, key=f"trainer{rank}") - # Parse per-engine allocations from config. self.actor_alloc = ModelAllocation.from_str(config.actor.backend, name="actor") self.rollout_alloc = ModelAllocation.from_str( config.rollout.backend, name="rollout" ) - # Validate config before proceeding with weight initialization self._validate_cfg() - self._amend_xccl_weight_update_envvar() - # Create models: actor, critic, ref — each with its own allocation. + logger.info("[DiagInit] PPOTrainer.__init__: creating actor engine...") + _t2 = _time.time() self.actor = self._create_train_engine(config.actor, self.actor_alloc) + logger.info(f"[DiagInit] PPOTrainer.__init__: actor engine created in {_time.time() - _t2:.2f}s") + self.critic = None if config.critic is not None: critic_alloc = ModelAllocation.from_str( config.critic.backend, name="critic" ) + logger.info("[DiagInit] PPOTrainer.__init__: creating critic engine...") + _t_crit = _time.time() self.critic = self._create_critic(config.critic, critic_alloc) + logger.info(f"[DiagInit] PPOTrainer.__init__: critic engine created in {_time.time() - _t_crit:.2f}s") self.ref = None if config.actor.kl_ctl > 0 and config.ref is not None: ref_alloc = ModelAllocation.from_str(config.ref.backend, name="ref") + logger.info("[DiagInit] PPOTrainer.__init__: creating ref engine...") + _t_ref = _time.time() self.ref = self._create_train_engine(config.ref, ref_alloc) + logger.info(f"[DiagInit] PPOTrainer.__init__: ref engine created in {_time.time() - _t_ref:.2f}s") - # Create dataloaders + logger.info("[DiagInit] PPOTrainer.__init__: creating dataloaders...") + _t3 = _time.time() self.train_dataset = train_dataset self.valid_dataset = valid_dataset if train_dataset is None: - # Online mode: require total_train_steps to compute steps_per_epoch. - # Without this, __len__()=1 causes every step to be treated as an - # epoch boundary, making Saver/RecoverHandler fire every step and - # corrupting the LR schedule. if config.total_train_steps is None: raise ValueError( "total_train_steps must be set for online mode " @@ -180,6 +189,7 @@ def __init__( rank=self.actor.data_parallel_rank, world_size=self.actor.data_parallel_world_size, ) + logger.info(f"[DiagInit] PPOTrainer.__init__: dataloaders created in {_time.time() - _t3:.2f}s") ft_spec = FinetuneSpec( total_train_epochs=config.total_train_epochs, @@ -187,12 +197,22 @@ def __init__( train_batch_size=config.train_dataset.batch_size, ) + logger.info("[DiagInit] PPOTrainer.__init__: initializing actor engine (workers, model loading)...") + _t4 = _time.time() engine_init_kwargs = {"addr": None, "ft_spec": ft_spec} self.actor.initialize(**engine_init_kwargs, role="actor") + logger.info(f"[DiagInit] PPOTrainer.__init__: actor engine initialized in {_time.time() - _t4:.2f}s") + if self.critic is not None: + logger.info("[DiagInit] PPOTrainer.__init__: initializing critic engine...") + _t_crit2 = _time.time() self.critic.initialize(**engine_init_kwargs, role="critic") + logger.info(f"[DiagInit] PPOTrainer.__init__: critic engine initialized in {_time.time() - _t_crit2:.2f}s") if self.ref is not None: + logger.info("[DiagInit] PPOTrainer.__init__: initializing ref engine...") + _t_ref2 = _time.time() self.ref.initialize(**engine_init_kwargs, role="ref") + logger.info(f"[DiagInit] PPOTrainer.__init__: ref engine initialized in {_time.time() - _t_ref2:.2f}s") self.teacher = None if config.teacher is not None: @@ -202,14 +222,18 @@ def __init__( self.teacher = self._create_train_engine(config.teacher, teacher_alloc) self.teacher.initialize(**engine_init_kwargs, role="teacher") - # Save initial LoRA weights if enabled (for inference server pre-loading) + logger.info("[DiagInit] PPOTrainer.__init__: saving initial LoRA weights...") + _t5 = _time.time() initial_lora_path = self._save_initial_lora_weights() + logger.info(f"[DiagInit] PPOTrainer.__init__: LoRA weights saved in {_time.time() - _t5:.2f}s") - # Initialize inference with LoRA path + logger.info("[DiagInit] PPOTrainer.__init__: initializing rollout engine...") + _t6 = _time.time() self.rollout = self._init_rollout( config.rollout, is_eval=False, lora_path=initial_lora_path ) - # Online mode detection: skip eval rollout for efficiency. + logger.info(f"[DiagInit] PPOTrainer.__init__: rollout engine initialized in {_time.time() - _t6:.2f}s") + openai_cfg = config.rollout.openai self._online_mode = train_dataset is None or ( openai_cfg is not None and openai_cfg.mode == "online" @@ -217,14 +241,17 @@ def __init__( self.eval_rollout = None if not self._online_mode: + logger.info("[DiagInit] PPOTrainer.__init__: initializing eval rollout...") + _t_eval = _time.time() self.eval_rollout = self._init_rollout( config.rollout, is_eval=True, lora_path=initial_lora_path ) + logger.info(f"[DiagInit] PPOTrainer.__init__: eval rollout initialized in {_time.time() - _t_eval:.2f}s") - # Proxy worker initialization (lazy, for AgentWorkflow support) self._proxy_started = False - # Prepare weight update meta and connect to inference engine + logger.info("[DiagInit] PPOTrainer.__init__: preparing weight update meta...") + _t7 = _time.time() if self.config.actor.weight_update_mode == "disk": disk_kwargs = { "experiment_name": config.experiment_name, @@ -243,7 +270,6 @@ def __init__( ) self.weight_update_meta = WeightUpdateMeta.from_disk(**disk_kwargs) elif self.config.actor.weight_update_mode == "xccl": - # NCCL/XCCL weight update xccl_kwargs: dict[str, Any] = { "gen_allocation": self.rollout_alloc, } @@ -267,20 +293,20 @@ def __init__( raise ValueError( f"Invalid weight update mode: {self.config.actor.weight_update_mode}" ) + logger.info(f"[DiagInit] PPOTrainer.__init__: weight update meta prepared in {_time.time() - _t7:.2f}s") + logger.info("[DiagInit] PPOTrainer.__init__: connecting actor to rollout engine...") + _t8 = _time.time() self.actor.connect_engine(self.rollout, self.weight_update_meta) + logger.info(f"[DiagInit] PPOTrainer.__init__: actor connected to rollout in {_time.time() - _t8:.2f}s") - # Set up evaluation (skip in online mode) self.evaluator = Evaluator(config.evaluator, ft_spec) - - # Set up save as HF model self.saver = Saver(config.saver, ft_spec) self.recover_handler = RecoverHandler(config.recover, ft_spec) - - # Set up statistics logging (wandb, tensoboard, etc.) self.stats_logger = StatsLogger(config, ft_spec) - # Set up checkpointing for recover + logger.info("[DiagInit] PPOTrainer.__init__: loading recover checkpoint...") + _t9 = _time.time() self.recover_info = self.recover_handler.load( self.actor, self.saver, @@ -290,9 +316,14 @@ def __init__( inference_engine=self.rollout, weight_update_meta=self.weight_update_meta, ) + logger.info(f"[DiagInit] PPOTrainer.__init__: recover checkpoint loaded in {_time.time() - _t9:.2f}s") self._config_perf_tracer() + logger.info( + f"[DiagInit] PPOTrainer.__init__: COMPLETED in {_time.time() - _t0:.2f}s total" + ) + def train( self, workflow: WorkflowLike | None = None, From 5feca788d1f3552df890dab287eca7dc6561443b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 27 Apr 2026 16:13:39 +0800 Subject: [PATCH 075/140] =?UTF-8?q?fix(scheduler):=20worker=20check?= =?UTF-8?q?=E3=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- areal/infra/scheduler/local.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index ba4b1d36e1..03778f860f 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -907,7 +907,7 @@ def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: try: response = requests.get(url, timeout=2.0) return response.status_code == 200 - except Exception: + except Exception as e: return False def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): @@ -916,7 +916,29 @@ def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): logger.info(f"[DiagInit] _configure_worker {worker_id}: waiting for /health endpoint...") _t0 = time.time() + _health_check_count = 0 while not self._is_worker_ready(worker_info): + _health_check_count += 1 + if _health_check_count % 50 == 1: + port = int(worker_info.worker.worker_ports[0]) + url = f"http://{format_hostport(worker_info.worker.ip, port)}/health" + elapsed = time.time() - _t0 + logger.info( + f"[DiagInit] _configure_worker {worker_id}: /health still not ready " + f"after {elapsed:.1f}s (url={url}, checks={_health_check_count}). " + f"Probing with detailed error..." + ) + try: + resp = requests.get(url, timeout=2.0) + logger.info( + f"[DiagInit] _configure_worker {worker_id}: probe got " + f"status={resp.status_code}, body={resp.text[:200]}" + ) + except Exception as probe_err: + logger.warning( + f"[DiagInit] _configure_worker {worker_id}: probe failed: " + f"{type(probe_err).__name__}: {probe_err}" + ) time.sleep(0.1) logger.info( f"[DiagInit] _configure_worker {worker_id}: /health ready in {time.time() - _t0:.2f}s" From a5177cc2a2514ac2c307efaf457d459045c324a9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 27 Apr 2026 17:26:59 +0800 Subject: [PATCH 076/140] fix(infra): fix net --- areal/infra/launcher/sglang_server.py | 2 ++ areal/infra/launcher/vllm_server.py | 2 ++ areal/infra/remote_inf_engine.py | 4 ++++ areal/infra/rpc/rtensor.py | 1 + areal/infra/scheduler/local.py | 18 ++++++++++++++++-- areal/infra/scheduler/slurm.py | 20 +++++++++++++++++--- areal/infra/utils/http.py | 13 ++++++++++++- areal/infra/workflow_context.py | 1 + areal/infra/workflow_executor.py | 2 ++ 9 files changed, 57 insertions(+), 6 deletions(-) diff --git a/areal/infra/launcher/sglang_server.py b/areal/infra/launcher/sglang_server.py index 90ce8a10b8..c75e2ac3a5 100644 --- a/areal/infra/launcher/sglang_server.py +++ b/areal/infra/launcher/sglang_server.py @@ -73,6 +73,8 @@ def wait_for_server(base_url: str, timeout: int | None = None) -> None: response = requests.get( f"{base_url}/v1/models", headers={"Authorization": "Bearer None"}, + proxies={"http": None, "https": None}, + verify=False, ) if response.status_code == 200: time.sleep(5) diff --git a/areal/infra/launcher/vllm_server.py b/areal/infra/launcher/vllm_server.py index fc2074bad2..6a36d1c00d 100644 --- a/areal/infra/launcher/vllm_server.py +++ b/areal/infra/launcher/vllm_server.py @@ -72,6 +72,8 @@ def wait_for_server(base_url: str, timeout: int | None = None) -> None: response = requests.get( f"{base_url}/v1/models", headers={"Authorization": "Bearer None"}, + proxies={"http": None, "https": None}, + verify=False, ) if response.status_code == 200: time.sleep(5) diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index 7890be4226..71e70b0b3a 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -1364,6 +1364,7 @@ async def _fn(): async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=self.config.request_timeout), read_bufsize=1024 * 1024 * 10, + trust_env=False, connector=get_default_connector(), ) as session: jobs = [] @@ -1477,6 +1478,7 @@ async def _fn(): async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=request_timeout), read_bufsize=1024 * 1024 * 10, + trust_env=False, connector=get_default_connector(), ) as session: for http_req in weight_reqs.requests: @@ -1523,6 +1525,7 @@ async def _fn(): async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=request_timeout), read_bufsize=1024 * 1024 * 10, + trust_env=False, connector=get_default_connector(), ) as session: jobs = [] @@ -1580,6 +1583,7 @@ async def _fn(): async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=request_timeout), read_bufsize=1024 * 1024 * 10, + trust_env=False, connector=get_default_connector(), ) as session: for req_idx, http_req in enumerate(weight_reqs.requests): diff --git a/areal/infra/rpc/rtensor.py b/areal/infra/rpc/rtensor.py index 899c408e35..afb6d9349b 100644 --- a/areal/infra/rpc/rtensor.py +++ b/areal/infra/rpc/rtensor.py @@ -98,6 +98,7 @@ def _create_session(self) -> aiohttp.ClientSession: return aiohttp.ClientSession( timeout=timeout, read_bufsize=10 * 1024 * 1024, # 10MB buffer + trust_env=False, connector=get_default_connector(), ) diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index 03778f860f..9689329b47 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -53,6 +53,9 @@ logger = logging.getLogger("LocalScheduler") +_NO_PROXY = {"http": None, "https": None} +_NO_PROXY_TRUST_ENV = False + @dataclass class WorkerInfo: @@ -466,6 +469,7 @@ async def _cleanup_forked_workers_async( timeout = aiohttp.ClientTimeout(total=30.0) async with aiohttp.ClientSession( timeout=timeout, + trust_env=False, connector=get_default_connector(), ) as session: tasks = [] @@ -497,6 +501,7 @@ async def _create_forked_workers_async( timeout = aiohttp.ClientTimeout(total=120.0) async with aiohttp.ClientSession( timeout=timeout, + trust_env=False, connector=get_default_connector(), ) as session: # Launch all fork requests concurrently with exception handling @@ -905,7 +910,9 @@ def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: url = f"http://{format_hostport(worker_info.worker.ip, port)}/health" try: - response = requests.get(url, timeout=2.0) + response = requests.get( + url, timeout=2.0, proxies=_NO_PROXY, verify=False + ) return response.status_code == 200 except Exception as e: return False @@ -929,7 +936,9 @@ def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): f"Probing with detailed error..." ) try: - resp = requests.get(url, timeout=2.0) + resp = requests.get( + url, timeout=2.0, proxies=_NO_PROXY, verify=False + ) logger.info( f"[DiagInit] _configure_worker {worker_id}: probe got " f"status={resp.status_code}, body={resp.text[:200]}" @@ -974,6 +983,8 @@ def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): data=payload_data, headers={"Content-Type": "application/json"}, timeout=300.0, + proxies=_NO_PROXY, + verify=False, ) logger.info( f"[DiagInit] _configure_worker {worker_id}: POST /configure responded " @@ -1124,6 +1135,7 @@ async def set_worker_env(self, worker_id: str, env: dict[str, str]) -> None: timeout = aiohttp.ClientTimeout(total=30.0) async with aiohttp.ClientSession( timeout=timeout, + trust_env=False, connector=get_default_connector(), ) as session: async with session.post( @@ -1222,6 +1234,7 @@ async def create_engine( async with aiohttp.ClientSession( timeout=timeout, read_bufsize=1024 * 1024 * 10, + trust_env=False, connector=get_default_connector(), ) as session: async with session.post( @@ -1512,6 +1525,7 @@ async def async_call_engine( async with aiohttp.ClientSession( timeout=timeo, read_bufsize=1024 * 1024 * 10, + trust_env=False, connector=get_default_connector(), ) as session: async with session.post( diff --git a/areal/infra/scheduler/slurm.py b/areal/infra/scheduler/slurm.py index 16ef5402d9..1f4e070c57 100644 --- a/areal/infra/scheduler/slurm.py +++ b/areal/infra/scheduler/slurm.py @@ -53,6 +53,8 @@ logger = logging.getLogger("SlurmScheduler") +_NO_PROXY = {"http": None, "https": None} + @dataclass class SlurmWorkerInfo: @@ -277,7 +279,7 @@ def _is_worker_ready(self, worker_info: SlurmWorkerInfo) -> bool: url = f"http://{format_hostport(worker_info.worker.ip, port)}/health" try: - response = requests.get(url, timeout=2.0) + response = requests.get(url, timeout=2.0, proxies=_NO_PROXY, verify=False) return response.status_code == 200 except Exception: return False @@ -305,6 +307,8 @@ def _configure_worker(self, worker_info: SlurmWorkerInfo, worker_rank: int) -> N ), headers={"Content-Type": "application/json"}, timeout=300.0, + proxies=_NO_PROXY, + verify=False, ) if response.status_code == 200: @@ -367,8 +371,10 @@ def _discover_worker_network(self, role: str) -> None: # Allocate new ports from the worker if worker_info.spec.port_count > 1: resp = requests.post( - f"http://{format_hostport(ip, port)}/alloc_ports", + f"http://{format_hostport(ip, port)}/alloc_ports", json=dict(count=worker_info.spec.port_count - 1), + proxies=_NO_PROXY, + verify=False, ) resp.raise_for_status() worker_ports += list(map(str, resp.json()["ports"])) @@ -657,6 +663,7 @@ async def _cleanup_forked_workers_async( timeout = aiohttp.ClientTimeout(total=30.0) async with aiohttp.ClientSession( timeout=timeout, + trust_env=False, connector=get_default_connector(), ) as session: tasks = [] @@ -688,6 +695,7 @@ async def _create_forked_workers_async( timeout = aiohttp.ClientTimeout(total=120.0) async with aiohttp.ClientSession( timeout=timeout, + trust_env=False, connector=get_default_connector(), ) as session: # Launch all fork requests concurrently with exception handling @@ -1307,6 +1315,7 @@ async def set_worker_env(self, worker_id: str, env: dict[str, str]) -> None: timeout = aiohttp.ClientTimeout(total=30.0) async with aiohttp.ClientSession( timeout=timeout, + trust_env=False, connector=get_default_connector(), ) as session: async with session.post( @@ -1397,6 +1406,7 @@ async def create_engine( async with aiohttp.ClientSession( timeout=timeout, read_bufsize=1024 * 1024 * 10, + trust_env=False, connector=get_default_connector(), ) as session: async with session.post( @@ -1516,7 +1526,10 @@ def call_engine( raise try: - response = requests.post(url, json=payload, timeout=http_timeout) + response = requests.post( + url, json=payload, timeout=http_timeout, + proxies=_NO_PROXY, verify=False, + ) if response.status_code == 200: result = response.json() @@ -1651,6 +1664,7 @@ async def async_call_engine( async with aiohttp.ClientSession( timeout=timeout, read_bufsize=1024 * 1024 * 10, + trust_env=False, connector=get_default_connector(), ) as session: async with session.post( diff --git a/areal/infra/utils/http.py b/areal/infra/utils/http.py index aa76abb835..f92eedcee0 100644 --- a/areal/infra/utils/http.py +++ b/areal/infra/utils/http.py @@ -15,7 +15,17 @@ def get_default_connector(): - return aiohttp.TCPConnector(limit=0, use_dns_cache=False, force_close=True) + return aiohttp.TCPConnector( + limit=0, use_dns_cache=False, force_close=True + ) + + +def get_default_session_kwargs(**overrides): + return { + "trust_env": False, + "connector": get_default_connector(), + **overrides, + } async def arequest_with_retry( @@ -49,6 +59,7 @@ async def arequest_with_retry( _session = aiohttp.ClientSession( timeout=timeo, read_bufsize=1024 * 1024 * 10, + trust_env=False, connector=get_default_connector(), ) else: diff --git a/areal/infra/workflow_context.py b/areal/infra/workflow_context.py index a828eaf617..a78f4979f9 100644 --- a/areal/infra/workflow_context.py +++ b/areal/infra/workflow_context.py @@ -128,6 +128,7 @@ async def get_aiohttp_session(self) -> aiohttp.ClientSession: self._aiohttp_session = aiohttp.ClientSession( timeout=timeout, read_bufsize=1024 * 1024 * 10, + trust_env=False, connector=get_default_connector(), ) # Track which event loop this session belongs to diff --git a/areal/infra/workflow_executor.py b/areal/infra/workflow_executor.py index ca3e50610f..a5a9e90a49 100644 --- a/areal/infra/workflow_executor.py +++ b/areal/infra/workflow_executor.py @@ -344,6 +344,8 @@ def post(): addr, json={"task_id": task_id}, timeout=30, + proxies={"http": None, "https": None}, + verify=False, ) resp.raise_for_status() except requests.RequestException as e: From 4e06ba9298056249a0aa075b08dc3d62b62ae430 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 27 Apr 2026 18:17:43 +0800 Subject: [PATCH 077/140] fix(net): add callback(need rethink) --- areal/utils/network.py | 45 +++++++++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/areal/utils/network.py b/areal/utils/network.py index cb7b1ff791..043d154798 100644 --- a/areal/utils/network.py +++ b/areal/utils/network.py @@ -17,6 +17,9 @@ def gethostip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str: Returns: The selected local IP address as a string. Supports both IPv4 and IPv6. + Falls back to ``127.0.0.1`` if the detected IP cannot be bound + (e.g. inside a Docker container where the external IP is not assigned + to any local interface). Raises: RuntimeError: If no suitable address can be determined @@ -28,27 +31,45 @@ def gethostip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str: if family == socket.AF_INET: ip = sockaddr[0] if ip and not ip.startswith("127."): - return ip + if _can_bind(ip): + return ip elif family == socket.AF_INET6: ip = sockaddr[0] if ip and ip != "::1": - return ip + if _can_bind(ip): + return ip except socket.gaierror: pass try: with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: sock.connect((probe_host, probe_port)) - return sock.getsockname()[0] - except OSError as e: - try: - with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: - sock.connect(("2001:4860:4860::8888", probe_port)) - ip6 = sock.getsockname()[0] - if ip6 and ip6 != "::1": - return ip6 - except OSError: - raise RuntimeError("Could not determine host IP") from e + ip = sock.getsockname()[0] + if _can_bind(ip): + return ip + except OSError: + pass + + try: + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: + sock.connect(("2001:4860:4860::8888", probe_port)) + ip6 = sock.getsockname()[0] + if ip6 and ip6 != "::1" and _can_bind(ip6): + return ip6 + except OSError: + pass + + return "127.0.0.1" + + +def _can_bind(ip: str) -> bool: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((ip, 0)) + return True + except OSError: + return False def get_loopback_ip() -> str: From 63497c7e81dfb4611d69f274660d5770047decf0 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 00:21:13 +0800 Subject: [PATCH 078/140] fix(engine): double scale --- areal/engine/megatron_engine.py | 17 +++++++++++++++++ examples/math/gsm8k_grpo_megatron_mimo.yaml | 1 + 2 files changed, 18 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index d6edcfd5ea..0a1251cc2c 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -3592,6 +3592,23 @@ def _compute_logprobs_and_loss( loss_scale, ) + if _mtp_loss_for_this_mb is not None and abs(loss_scale) > 0: + _inv = 1.0 / loss_scale + # Subtract the already-added mtp and re-add with inverse scaling + # so `(loss) * loss_scale == loss_rl * loss_scale + mtp`. + loss = (loss - _mtp_contribution) + _mtp_contribution * _inv + _n_ds = self._mtp_loss_total_count + if _n_ds <= 4 or _n_ds % 100 == 0: + self.logger.info( + "[MTPFix-DoubleScale] Inverse-loss_scale applied: " + "loss_scale=%.6f, inv=%.4f, mtp_contribution=%.6f, " + "effective_mtp_in_final_loss=%.6f (verl-equivalent, " + "single mtp_loss_scaling_factor application)", + loss_scale, _inv, + _mtp_contribution.detach().item(), + _mtp_contribution.detach().item(), + ) + return loss * loss_scale def _compute_forward_result( diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index b7624a2cc6..e47581b082 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -82,6 +82,7 @@ actor: enable_mtp_training: true mtp_num_layers: 1 mtp_loss_scaling_factor: 0.2 + mtp_lr_scale: 10.0 scheduling_spec: - task_type: worker From 5b84634cf03e4d247f3d370866f6516368890ced Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 00:28:12 +0800 Subject: [PATCH 079/140] feat(actor): fix mtp_lr_scale --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index e47581b082..550ea6f53d 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -60,6 +60,7 @@ actor: lr_scheduler_type: constant gradient_clipping: 1.0 warmup_steps_proportion: 0.001 + mtp_lr_scale: 10.0 eps_clip: 0.4 temperature: ${gconfig.temperature} reward_scaling: 10.0 @@ -82,7 +83,6 @@ actor: enable_mtp_training: true mtp_num_layers: 1 mtp_loss_scaling_factor: 0.2 - mtp_lr_scale: 10.0 scheduling_spec: - task_type: worker From 056724cdd6137b915837af9a0c072020b8403b97 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 14:06:36 +0800 Subject: [PATCH 080/140] fix(engine): fix mtp gradient numbatch --- areal/engine/megatron_engine.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 0a1251cc2c..d5ca29ce4d 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1896,6 +1896,11 @@ def train_batch( mb_list, loss_weight_fn, mpu.get_data_parallel_group() ) + # Expose num_microbatches to _compute_logprobs_and_loss so + # the DoubleScale inversion can further divide the MTP contribution by + # num_mb. + self._current_num_microbatches = int(len(mb_list)) + # Step 3: Forward-backward using Megatron's pipeline function loss_multiplier = ( mpu.get_data_parallel_world_size() * self.optimizer.get_loss_scale().item() @@ -3593,20 +3598,29 @@ def _compute_logprobs_and_loss( ) if _mtp_loss_for_this_mb is not None and abs(loss_scale) > 0: - _inv = 1.0 / loss_scale - # Subtract the already-added mtp and re-add with inverse scaling - # so `(loss) * loss_scale == loss_rl * loss_scale + mtp`. + # Match Megatron-native MTPLossAutoScaler: + # schedules.py sets main_loss_backward_scale = loss_scale + # / num_microbatches. + _num_mb = max(1, int(getattr(self, "_current_num_microbatches", 1))) + _inv = 1.0 / (loss_scale * _num_mb) + # Subtract already-added mtp and re-add with corrected scaling so + # `loss * loss_scale` contributes (mtp_loss_scale * mtp_loss) / + # num_mb per microbatch loss = (loss - _mtp_contribution) + _mtp_contribution * _inv _n_ds = self._mtp_loss_total_count if _n_ds <= 4 or _n_ds % 100 == 0: + _eff_per_mb = ( + _mtp_contribution.detach().item() * _inv * loss_scale + ) self.logger.info( - "[MTPFix-DoubleScale] Inverse-loss_scale applied: " - "loss_scale=%.6f, inv=%.4f, mtp_contribution=%.6f, " - "effective_mtp_in_final_loss=%.6f (verl-equivalent, " - "single mtp_loss_scaling_factor application)", - loss_scale, _inv, - _mtp_contribution.detach().item(), + "[MTPFix-DoubleScale-v6] Inverse-(loss_scale*num_mb) " + "applied: loss_scale=%.6f, num_mb=%d, inv=%.4f, " + "mtp_contribution=%.6f, effective_mtp_contrib_per_mb=" + "%.6f (accumulated over num_mb MBs = mtp_loss_scale * " + "mtp_loss; verl/megatron-native equivalent).", + loss_scale, _num_mb, _inv, _mtp_contribution.detach().item(), + _eff_per_mb, ) return loss * loss_scale From 8b906666a109c9c9c85966f13c711b220b55bff2 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 15:56:50 +0800 Subject: [PATCH 081/140] fix(engine): lr --- areal/engine/megatron_engine.py | 68 ++++++++++++++++----- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index d5ca29ce4d..8c0b197f32 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1202,18 +1202,20 @@ def forward_step(batch_iter, model): self._mtp_loss_clip_count = 0 self._mtp_loss_total_count = 0 # [v5-F6] Hint SpecDec v2 env toggle for throughput (idempotent, - # rank-0 only to avoid N-rank log spam). + # rank-0 only to avoid N-rank log spam, print once only). import os as _os_v5 try: _rank_v5 = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 except Exception: _rank_v5 = 0 if _rank_v5 == 0 and _os_v5.environ.get("SGLANG_ENABLE_SPEC_V2", "") == "": - self.logger.info( - "[MTPEnvHint] SGLANG_ENABLE_SPEC_V2 not set; " - "consider exporting SGLANG_ENABLE_SPEC_V2=True to " - "enable overlap scheduler for speculative decoding." - ) + if not getattr(self, '_mtp_env_hint_printed', False): + self._mtp_env_hint_printed = True + self.logger.info( + "[MTPEnvHint] SGLANG_ENABLE_SPEC_V2 not set; " + "consider exporting SGLANG_ENABLE_SPEC_V2=True to " + "enable overlap scheduler for speculative decoding." + ) _unwrapped = model while hasattr(_unwrapped, "module"): @@ -3048,8 +3050,20 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _d = _cur - _prev _all_norms.append((_tn, _cur, _d)) _deltas.append(abs(_d)) - # Stall: weight changed by <1e-5 absolute (essentially frozen). - if _cur > 0 and abs(_d) < 1e-5: + # LR-adaptive STALL threshold. bf16-eps is ~7.8e-3 per + # element and typical lr*grad_norm for MTP is ~1e-7..1e-6, + # so the previous 1e-5 absolute threshold mis-flagged every + # LN/bias tensor as "stalled" + try: + _mtp_lr_cur = float( + getattr(self, "_last_logged_mtp_lr", 3e-6) + ) + except Exception: + _mtp_lr_cur = 3e-6 + # Expected per-step drift ~ lr * grad_norm; anything + # <5% of that is truly frozen. + _stall_thr = max(1e-9, 0.05 * _mtp_lr_cur * max(_cur, 1.0)) + if _cur > 0 and abs(_d) < _stall_thr: _stall_tensors.append(_tn) self._mtp_sync_prev_norms[_tn] = _cur # Compact per-tensor summary line (rank-0 only to avoid DP-spam). @@ -3077,9 +3091,11 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ", ".join(_fmt_parts), _drift_summary, ) - # [v5-F5] Stall warning: if >50% of MTP tensors show sub-1e-5 drift, - # the draft model isn't learning — root cause of accept-rate collapse. - if _deltas and len(_stall_tensors) >= 0.5 * len(_deltas): + # Stall warning: require >=90% (was 50%) AND LR-adaptive + # threshold above. Previously bf16 noise floor mis-fired every + # step; the new combined criterion only fires if the MTP + # optimizer is truly dead. + if _deltas and len(_stall_tensors) >= 0.9 * len(_deltas): self.logger.warning( "[MTPSyncHealth] MTP training STALL detected at version=%d: " "%d/%d tensors drift<1e-5. " @@ -3598,6 +3614,17 @@ def _compute_logprobs_and_loss( ) if _mtp_loss_for_this_mb is not None and abs(loss_scale) > 0: + # [v8] Refresh cached MTP LR from optimizer param_groups so the + # DoubleScale log and SyncHealth STALL threshold can use the + # realised LR (not a hardcoded default). + try: + for _pg in getattr(self.optimizer, "param_groups", []): + _nm = str(_pg.get("name", "")) + if "mtp" in _nm.lower(): + self._last_logged_mtp_lr = float(_pg.get("lr", 3e-6)) + break + except Exception: + pass # Match Megatron-native MTPLossAutoScaler: # schedules.py sets main_loss_backward_scale = loss_scale # / num_microbatches. @@ -3612,15 +3639,28 @@ def _compute_logprobs_and_loss( _eff_per_mb = ( _mtp_contribution.detach().item() * _inv * loss_scale ) + # Also surface the realised per-step MTP weight update + # magnitude estimate (= eff_contrib * mtp_lr). This directly + # monitors whether the draft head is actually learning, and + # its drift exposes data-shape driven instability + try: + _mtp_lr_dbg = float( + getattr(self, "_last_logged_mtp_lr", 3e-6) + ) + except Exception: + _mtp_lr_dbg = 3e-6 + _eff_step_mag = _eff_per_mb * _mtp_lr_dbg self.logger.info( "[MTPFix-DoubleScale-v6] Inverse-(loss_scale*num_mb) " "applied: loss_scale=%.6f, num_mb=%d, inv=%.4f, " "mtp_contribution=%.6f, effective_mtp_contrib_per_mb=" - "%.6f (accumulated over num_mb MBs = mtp_loss_scale * " - "mtp_loss; verl/megatron-native equivalent).", + "%.6f, mtp_lr=%.3e, effective_per_step_update~=%.3e " + "(warn if <1e-8; accumulated over num_mb MBs = " + "mtp_loss_scale * mtp_loss; verl/megatron-native " + "equivalent).", loss_scale, _num_mb, _inv, _mtp_contribution.detach().item(), - _eff_per_mb, + _eff_per_mb, _mtp_lr_dbg, _eff_step_mag, ) return loss * loss_scale diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 550ea6f53d..6b09adf474 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -60,7 +60,7 @@ actor: lr_scheduler_type: constant gradient_clipping: 1.0 warmup_steps_proportion: 0.001 - mtp_lr_scale: 10.0 + mtp_lr_scale: 1.0 eps_clip: 0.4 temperature: ${gconfig.temperature} reward_scaling: 10.0 From be8c1b0f31d5a4da8a9dfa0dec500096ece2cc55 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 17:27:10 +0800 Subject: [PATCH 082/140] feat(engine): megatron log --- areal/engine/megatron_engine.py | 107 +++++++++++++++++++++++++++----- 1 file changed, 92 insertions(+), 15 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 8c0b197f32..7b99e03205 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1902,6 +1902,18 @@ def train_batch( # the DoubleScale inversion can further divide the MTP contribution by # num_mb. self._current_num_microbatches = int(len(mb_list)) + # expose total token count for [MTPDataShapeDiag-v9] so + # tokens_per_mb can be logged and correlated with accept_rate + # regressions + try: + _tot = 0 + for _mb in mb_list: + _ids = _mb.get("input_ids") if isinstance(_mb, dict) else None + if _ids is not None and hasattr(_ids, "numel"): + _tot += int(_ids.numel()) + self._current_n_tokens = _tot + except Exception: + self._current_n_tokens = 0 # Step 3: Forward-backward using Megatron's pipeline function loss_multiplier = ( @@ -3050,19 +3062,35 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _d = _cur - _prev _all_norms.append((_tn, _cur, _d)) _deltas.append(abs(_d)) - # LR-adaptive STALL threshold. bf16-eps is ~7.8e-3 per - # element and typical lr*grad_norm for MTP is ~1e-7..1e-6, - # so the previous 1e-5 absolute threshold mis-flagged every - # LN/bias tensor as "stalled" + # [v9] bf16-quantization-aware STALL threshold. v8 used + # 0.05 * lr * norm which, for a LayerNorm of dim=4096 + # (norm~64, bf16_eps per-element ~7.6e-6), yielded + # ~9.6e-6 — same order as the bf16 stochastic-rounding + # noise floor. That still mis-fired STALL 10/14 times + # in the 0428 v7 log even though mtp_loss was + # converging 646->145 (training clearly healthy). + # v9 formula: use bf16 round-trip error as the true + # floor, and ONLY warn after N consecutive sub-floor + # versions to avoid any transient data-shape blip. + # bf16 eps ~= 2^-7 (relative), so quantization error + # on |w| ~ 1 is ~7.8e-3 per element; for a tensor of + # numel elements the L2-norm of the quantization + # delta is ~sqrt(numel) * 7.8e-3 / 2 (average). But + # our metric is the delta between two norms, not + # the norm of the delta, and the norm itself is + # already rounded each time — so the per-sync + # observable floor is ~2^-17 * norm ~= 7.6e-6 * norm. try: _mtp_lr_cur = float( getattr(self, "_last_logged_mtp_lr", 3e-6) ) except Exception: _mtp_lr_cur = 3e-6 - # Expected per-step drift ~ lr * grad_norm; anything - # <5% of that is truly frozen. - _stall_thr = max(1e-9, 0.05 * _mtp_lr_cur * max(_cur, 1.0)) + _bf16_floor = 7.6e-6 * max(_cur, 1.0) + _expected_drift = max( + 1e-9, _mtp_lr_cur * max(_cur, 1.0) * 0.1 + ) + _stall_thr = max(_bf16_floor, _expected_drift) if _cur > 0 and abs(_d) < _stall_thr: _stall_tensors.append(_tn) self._mtp_sync_prev_norms[_tn] = _cur @@ -3091,19 +3119,50 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ", ".join(_fmt_parts), _drift_summary, ) - # Stall warning: require >=90% (was 50%) AND LR-adaptive - # threshold above. Previously bf16 noise floor mis-fired every - # step; the new combined criterion only fires if the MTP - # optimizer is truly dead. - if _deltas and len(_stall_tensors) >= 0.9 * len(_deltas): + # Windowed STALL: only warn if ALL of the last 3 + # consecutive versions flagged >=90% tensors stalled AND + # the *cumulative* drift over the window is below floor. + # This eliminates bf16 round-trip false alarms while + if not hasattr(self, "_mtp_stall_window"): + self._mtp_stall_window = [] # list of (version, pct, sum_d) + _this_pct = ( + len(_stall_tensors) / len(_deltas) if _deltas else 0.0 + ) + _this_sum_d = sum(_deltas) if _deltas else 0.0 + self._mtp_stall_window.append( + (meta.version, _this_pct, _this_sum_d) + ) + if len(self._mtp_stall_window) > 3: + self._mtp_stall_window.pop(0) + # Diagnostic: always log the window state to make + # subsequent triage self-evident. + _win_fmt = ",".join( + f"v{v}:{p*100:.0f}%/Σ={s:.1e}" + for v, p, s in self._mtp_stall_window + ) + _bf16_floor_total = 7.6e-6 * len(_deltas) * 64 # ~per-tensor floor * 64 + self.logger.info( + "[MTPSyncHealth-v9] STALL window (last %d syncs): " + "[%s] | bf16_floor_est=%.2e", + len(self._mtp_stall_window), _win_fmt, + _bf16_floor_total, + ) + if ( + len(self._mtp_stall_window) >= 3 + and all(p >= 0.9 for _, p, _ in self._mtp_stall_window) + and sum(s for _, _, s in self._mtp_stall_window) + < _bf16_floor_total * 2 + ): self.logger.warning( - "[MTPSyncHealth] MTP training STALL detected at version=%d: " - "%d/%d tensors drift<1e-5. " + "[MTPSyncHealth] MTP training STALL detected at " + "version=%d (3 consecutive sub-floor syncs): " + "%d/%d tensors drift Date: Tue, 28 Apr 2026 19:20:01 +0800 Subject: [PATCH 083/140] feat(engine): audit log --- areal/engine/megatron_engine.py | 111 ++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 7b99e03205..cb123c7e61 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1029,6 +1029,7 @@ def _collect_mtp_loss(self) -> dict[str, float]: name, _has_mg, _has_g, _flag_v, param.requires_grad, _mg_ptr) mtp_stats["mtp_grad_norm"] = mtp_g**0.5 + self._last_mtp_grad_norm = mtp_g**0.5 mtp_stats["non_mtp_grad_norm"] = non_mtp_g**0.5 mtp_stats["mtp_backward_scale"] = ( float(scale_str) if scale_str != "N/A" else 0.0) @@ -3147,6 +3148,41 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: len(self._mtp_stall_window), _win_fmt, _bf16_floor_total, ) + # One-line version->step audit trail. Makes + try: + _gn_audit = float( + getattr(self, "_last_mtp_grad_norm", 0.0) + ) + except Exception: + _gn_audit = 0.0 + try: + _step_audit = int( + getattr(self, "_global_step", 0) + ) + except Exception: + _step_audit = 0 + try: + _ntok_audit = int( + getattr(self, "_current_n_tokens", 0) + ) + except Exception: + _ntok_audit = 0 + try: + _nmb_audit = int( + getattr(self, "_current_num_microbatches", 0) + ) + except Exception: + _nmb_audit = 0 + self.logger.info( + "[MTPVersionAudit-v11] version=%d step=%d " + "mtp_grad_norm=%.4e num_mb=%d n_tokens=%d " + "max|Δ|=%.3e sum|Δ|=%.3e stalled_frac=%d/%d", + meta.version, _step_audit, _gn_audit, + _nmb_audit, _ntok_audit, + max(_deltas) if _deltas else 0.0, + sum(_deltas) if _deltas else 0.0, + len(_stall_tensors), len(_deltas), + ) if ( len(self._mtp_stall_window) >= 3 and all(p >= 0.9 for _, p, _ in self._mtp_stall_window) @@ -3165,6 +3201,30 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: sum(s for _, _, s in self._mtp_stall_window), ", ".join(_stall_tensors[:3]), ) + # the draft IS training and the "stall" is a bf16 + # quantization artefact at the broadcast boundary. + _last_gn = float(getattr(self, "_last_mtp_grad_norm", 0.0)) + # Additional liveness escape hatch + if _last_gn > 1e-4: + try: + _step_sup = int( + getattr(self, "_global_step", 0) + ) + except Exception: + _step_sup = 0 + self.logger.info( + "[MTPSyncHealth-v10] STALL candidate at " + "version=%d step=%d SUPPRESSED: " + "last mtp_grad_norm=%.4e > 1e-4 (draft IS " + "learning; bf16 quantization at broadcast " + "absorbs sub-ULP weight updates). Window: " + "%d/%d tensors 5: + self._mtp_tok_trend.pop(0) + if ( + len(self._mtp_tok_trend) >= 5 + and self._mtp_tok_trend[0][1] > 0 + and _n_tokens > 0 + ): + _prev_avg = sum( + t for _, t, _ in self._mtp_tok_trend[:-1] + ) / max(1, len(self._mtp_tok_trend) - 1) + _drop_pct = ( + 1.0 - _n_tokens / _prev_avg + ) if _prev_avg > 0 else 0.0 + _tok_trend_msg = ",".join( + f"s{s}:{t//1000}k/{n}mb" + for s, t, n in self._mtp_tok_trend + ) + if _drop_pct > 0.3: + self.logger.warning( + "[MTPDataTrend-v11] SEQUENCE-LENGTH " + "COLLAPSE: n_tokens dropped %.1f%% vs " + "5-step trailing avg (%.0f -> %d). " + "Trend: [%s]. Draft head will see " + "fewer tokens per update; accept_rate " + "regression is likely within 1-2 " + "versions. Mitigations: raise " + "mtp_loss_scaling_factor, enable reward " + "clipping, or widen rollout batch.", + _drop_pct * 100.0, _prev_avg, _n_tokens, + _tok_trend_msg, + ) + elif _drop_pct > 0.15: + self.logger.info( + "[MTPDataTrend-v11] mild token drop " + "%.1f%% (%.0f -> %d) over last 5 " + "steps. Trend: [%s].", + _drop_pct * 100.0, _prev_avg, _n_tokens, + _tok_trend_msg, + ) return loss * loss_scale From 23ebf0e843e3e6722fbef20eeb67365133603b82 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 21:59:06 +0800 Subject: [PATCH 084/140] fix(megatron_engine): mimo weight update --- areal/api/cli_args.py | 8 +- areal/engine/megatron_engine.py | 41 ++++++- areal/engine/megatron_utils/megatron.py | 152 ++++++++++++++++++++---- 3 files changed, 171 insertions(+), 30 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 677789e12d..4b29ef3feb 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1684,7 +1684,13 @@ class SGLangConfig: "help": "Attention mode for speculative decoding. E.g., 'full', 'sparse'." }, ) - enable_multi_layer_eagle: bool = False + enable_multi_layer_eagle: bool = field( + default=False, + metadata={ + "help": "Enable multi-layer EAGLE draft head (SGLang only). " + "Required when the draft model has more than one MTP layer." + }, + ) enable_draft_weights_cpu_backup: bool | None = field( default=None, metadata={ diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index cb123c7e61..0694da1ac9 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2406,12 +2406,36 @@ def _update_bucket_weights_from_distributed( _t_bc0 = _diag_time.time() handles = [] + _mtp_upcast_count = 0 for idx, (name, param) in enumerate(converted_named_tensors): + # MTP draft-head deltas are typically smaller than bf16 ULP + # (see [MTPSyncHealth] stall diagnostics). Upcast MTP tensors + # to fp32 on the trainer side before NCCL broadcast so the + # inference-side draft head sees the full precision update. + # The rollout side will downcast during load_weights. + send_tensor = param.data + if ( + (".enorm" in name or ".hnorm" in name or ".eh_proj" in name + or ".shared_head." in name or ".mtp_layers." in name) + and send_tensor.dtype == torch.bfloat16 + ): + send_tensor = send_tensor.float().contiguous() + # rebind so the receiver (whose dtype spec was already + # promoted in build_tensor_weight_update_request) matches. + converted_named_tensors[idx] = (name, send_tensor) + _mtp_upcast_count += 1 handles.append( dist.broadcast( - param.data, 0, group=self.weight_update_group, async_op=True + send_tensor, 0, group=self.weight_update_group, async_op=True ) ) + if _mtp_upcast_count > 0: + self.logger.info( + "[MTPBroadcastDtype] Upcast %d MTP tensors to fp32 for " + "NCCL broadcast (avoid bf16 ULP absorption of draft-head " + "weight deltas).", + _mtp_upcast_count, + ) self.logger.info( f"[DiagBucket] Enqueued {len(handles)} async broadcasts " f"in {_diag_time.time() - _t_bc0:.3f}s, calling handle.wait()..." @@ -2432,9 +2456,18 @@ def _update_bucket_weights_from_distributed( try: fut.result(timeout=30) except TimeoutError: - self.logger.warning( - "Callback response timed out, but NCCL broadcast " - "completed successfully. Continuing weight update." + # This was previously silently swallowed. Surface loudly: + # if the callback never finishes, the inference engine may + # have partially applied the broadcast, desyncing the draft. + self.logger.error( + "[MTPBroadcastTimeout] Callback response timed out after " + "30s while waiting for rollout side update_weights_from_" + "distributed to acknowledge. NCCL broadcast completed on " + "trainer side but the inference engine may NOT have " + "finished applying the weights. This CAN silently desync " + "MTP draft head and cause accept_rate decay. " + "n_tensors=%d, n_specs=%d.", + len(converted_named_tensors), len(param_specs), ) except Exception as e: self.logger.warning( diff --git a/areal/engine/megatron_utils/megatron.py b/areal/engine/megatron_utils/megatron.py index 62f8132be0..93c7b9c43b 100644 --- a/areal/engine/megatron_utils/megatron.py +++ b/areal/engine/megatron_utils/megatron.py @@ -1,4 +1,5 @@ import re +import logging import torch import torch.distributed as dist @@ -180,38 +181,59 @@ def _convert_mtp_layer_to_hf( param: Parameter | Tensor | FP8BlockwiseTensorHelper, tf_config: TransformerConfig, ) -> list[tuple[str, Tensor]] | None: - """Convert MCore MTP layer parameter names to HuggingFace format. - - MCore MTP layers follow the naming pattern: - module.module.decoder.mtp_layers.{layer_idx}.{submodule}.{param} - which maps to HF format: - model.mtp_layers.{layer_idx}.{submodule}.{param} - - Returns a list of (hf_name, param) tuples if the parameter is an MTP + """Generic MCore -> HF converter for a **MiMo-style** MTP layer. + + This function is kept for backwards compatibility for models whose HF + layout stores MTP tensors under ``model.mtp_layers.{i}.*`` (e.g. MiMo). + Models such as DeepSeek-V3 / GLM4-MoE (HF layout appends MTP as a + regular ``model.layers.{num_layers + i}.*`` with ``shared_head.norm``) + MUST provide a model-specific MTP converter instead; for those models + the DeepSeek/GLM-specific branch in ``convert_deepseekv3_to_hf`` / + ``convert_glm4moe_to_hf`` short-circuits before reaching this function. + + Handled MCore name patterns (both versions produced by different + megatron-core builds are accepted so this helper works regardless of + whether MTP lives under ``decoder.mtp_layers`` or top-level ``mtp.layers``): + module.module.decoder.mtp_layers.{idx}.{component} + module.module.mtp.layers.{idx}.{component} + + Returns a list of (hf_name, tensor) tuples if the parameter is an MTP parameter, or None if it is not. + + IMPORTANT: the previous implementation contained no-op + ``replace("enorm.weight", "enorm.weight")`` calls and a TODO-style + comment. It silently emitted mis-named tensors for every non-MiMo model + that routed through it (DeepSeek-V3 / GLM4-MoE / Bailing / Qwen3-MoE), + which in turn caused the SGLang rollout engine to silently skip MTP + weight updates (SpecDec accept_rate monotone decay). The behaviour is + now explicit: this helper ONLY emits names under + ``model.mtp_layers.{idx}.*`` and the non-MiMo callers no longer invoke + it on MTP params. """ - import re + logger = logging.getLogger(__name__) + # Accept both naming conventions produced by different mcore versions. mtp_match = re.match( r"module\.module\.decoder\.mtp_layers\.(\d+)\.(.+)", name ) + if mtp_match is None: + mtp_match = re.match( + r"module\.module\.mtp\.layers\.(\d+)\.(.+)", name + ) if mtp_match is None: return None layer_idx = int(mtp_match.group(1)) remainder = mtp_match.group(2) - # Map common MCore submodule names to HF names - hf_remainder = remainder - - # enorm / hnorm -> input_layernorm / post_attention_layernorm equivalent - hf_remainder = hf_remainder.replace("enorm.weight", "enorm.weight") - hf_remainder = hf_remainder.replace("hnorm.weight", "hnorm.weight") - - # Note: Some models (e.g., MiMo) may need column-half swap for eh_proj. - # This should be handled in model-specific conversion functions, not here. - # The generic MTP converter passes eh_proj through unchanged. - - hf_name = f"model.mtp_layers.{layer_idx}.{hf_remainder}" + # Keep the MCore remainder verbatim; MiMo HF layout expects + # ``model.mtp_layers.{idx}.{enorm|hnorm|eh_proj|final_layernorm}.weight`` + # and model-specific converters (e.g. ``_convert_mimo_mtp_param``) + # perform the name rewriting + eh_proj column-half swap themselves. + hf_name = f"model.mtp_layers.{layer_idx}.{remainder}" + logger.debug( + "[MTPConvertGeneric] mcore=%s -> hf=%s shape=%s", + name, hf_name, tuple(param.shape), + ) return [(hf_name, param)] # Adapted from slime @@ -472,6 +494,83 @@ def convert_qwen2_to_hf( # Adapted from slime +def _convert_deepseekv3_mtp_param( + tf_config: TransformerConfig, + name: str, + param: "Parameter | Tensor | FP8BlockwiseTensorHelper", +): + """DeepSeek-V3 MTP MCore -> HF converter. + + Mirrors the layout used by SGLang ``DeepseekV3ForCausalLMNextN`` and + matches slime / verl conventions: + mcore ``mtp.layers.{i}.enorm.weight`` -> ``model.layers.{N+i}.enorm.weight`` + mcore ``mtp.layers.{i}.hnorm.weight`` -> ``model.layers.{N+i}.hnorm.weight`` + mcore ``mtp.layers.{i}.eh_proj.weight`` -> ``model.layers.{N+i}.eh_proj.weight`` + mcore ``mtp.layers.{i}.final_layernorm.weight`` -> ``model.layers.{N+i}.shared_head.norm.weight`` + mcore ``mtp.layers.{i}.transformer_layer.<...>`` -> runs through the regular DSv3 + attention/MLP/MoE mappers by + rewriting the proxy name to + ``decoder.layers.{N+i}.<...>``. + + ``embed_tokens`` and ``shared_head.head`` are tied to the main model and + are NOT emitted from the MTP block (SGLang skips them during load). + + NOTE: unlike MiMo, there is NO column-half swap on ``eh_proj.weight`` + for DeepSeek-V3; slime and verl both pass it through unchanged. + """ + logger = logging.getLogger(__name__) + match = re.match(r"module\.module\.mtp\.layers\.(\d+)\.(.+)", name) + if match is None: + match = re.match( + r"module\.module\.decoder\.mtp_layers\.(\d+)\.(.+)", name + ) + if match is None: + return None + + mtp_local_idx, rest = match.groups() + mtp_local_idx = int(mtp_local_idx) + try: + num_layers = int(tf_config.num_layers) + except Exception as e: + raise ValueError( + f"[MTPConvertDSv3] cannot read num_layers from tf_config ({e}); " + "needed to compute HF MTP layer index." + ) + hf_layer_idx = num_layers + mtp_local_idx + + direct = { + "enorm.weight": f"model.layers.{hf_layer_idx}.enorm.weight", + "hnorm.weight": f"model.layers.{hf_layer_idx}.hnorm.weight", + "eh_proj.weight": f"model.layers.{hf_layer_idx}.eh_proj.weight", + "final_layernorm.weight": ( + f"model.layers.{hf_layer_idx}.shared_head.norm.weight" + ), + } + if rest in direct: + logger.info( + "[MTPConvertDSv3] mcore=%s -> hf=%s (direct) shape=%s", + name, direct[rest], tuple(param.shape), + ) + return [(direct[rest], param)] + + # transformer_layer.* ==> run the regular DSv3 mapper on a proxy name + # that pretends this MTP block is ``decoder.layers.{num_layers+i}``. + if not rest.startswith("transformer_layer."): + raise ValueError( + f"[MTPConvertDSv3] unsupported MTP component {rest!r} in {name!r}" + ) + inner = rest[len("transformer_layer."):] + proxy_name = f"module.module.decoder.layers.{hf_layer_idx}.{inner}" + logger.info( + "[MTPConvertDSv3] delegating transformer_layer: mcore=%s via proxy=%s", + name, proxy_name, + ) + # Call the main DSv3 converter with the proxy name. This reuses all of + # the attention/MLP/MoE mapping logic and yields the correct + # ``model.layers.{hf_layer_idx}.*`` HF keys in one shot. + return convert_deepseekv3_to_hf(tf_config, proxy_name, param) + + def convert_deepseekv3_to_hf( tf_config: TransformerConfig, name: str, @@ -484,10 +583,13 @@ def convert_deepseekv3_to_hf( if name == "module.module.decoder.final_layernorm.weight": return [("model.norm.weight", param)] - # Check for MTP layer parameters - mtp_result = _convert_mtp_layer_to_hf(name, param, tf_config) - if mtp_result is not None: - return mtp_result + # MTP layer parameters are routed through a DSv3-specific converter + # that emits the correct HF layout for SGLang + # (model.layers.{num_layers+i}.{enorm,hnorm,eh_proj,shared_head.norm,...}). + if ".mtp." in name or ".mtp_layers." in name: + mtp_result = _convert_deepseekv3_mtp_param(tf_config, name, param) + if mtp_result is not None: + return mtp_result try: head_dim = ( From b0d9363f2009ffc98a148b0fd83d02d0ee84da85 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 11:30:22 +0800 Subject: [PATCH 085/140] fix: scale up mtp_lr_scale --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 6b09adf474..983a665c53 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -60,7 +60,7 @@ actor: lr_scheduler_type: constant gradient_clipping: 1.0 warmup_steps_proportion: 0.001 - mtp_lr_scale: 1.0 + mtp_lr_scale: 1.0e4 eps_clip: 0.4 temperature: ${gconfig.temperature} reward_scaling: 10.0 From 9c29945b4400cd716d8d737ff1874af893aa825c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 14:12:42 +0800 Subject: [PATCH 086/140] fix(megatron_engine): add log --- areal/engine/megatron_engine.py | 83 +++++++++++++++++++-- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 2 files changed, 76 insertions(+), 9 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 0694da1ac9..832eb4cae8 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -2317,6 +2317,33 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: config_overrides=_mtp_lr_config_overrides, ) + # [MTPOptim-v12] Dump param_groups to verify ParamKey override + # actually installed. Megatron 0.16 ParamKey does NOT attach a `name` + # field to param_groups, so downstream identification must use + # `max_lr` fingerprint instead of name match. + try: + _base_max_lr = float(self.optimizer_config.lr) + for _idx, _pg in enumerate( + getattr(self.optimizer, "param_groups", []) or [] + ): + _n_params = len(_pg.get("params", []) or []) + _mxlr = _pg.get("max_lr", None) + _mnlr = _pg.get("min_lr", None) + _is_mtp = ( + _mxlr is not None + and abs(float(_mxlr) - _base_max_lr) > 1e-12 + ) + self.logger.info( + "[MTPOptim-v12] param_group[%d]: n_params=%d max_lr=%s " + "min_lr=%s is_mtp_group=%s", + _idx, _n_params, str(_mxlr), str(_mnlr), + str(_is_mtp), + ) + except Exception as _e: + self.logger.warning( + "[MTPOptim-v12] param_groups dump failed: %s", _e + ) + warmup_steps_proportion = self.optimizer_config.warmup_steps_proportion warmup_steps = int(warmup_steps_proportion * ft_spec.total_train_steps) lr_scheduler = OptimizerParamScheduler( @@ -3766,15 +3793,37 @@ def _compute_logprobs_and_loss( ) if _mtp_loss_for_this_mb is not None and abs(loss_scale) > 0: - # [v8] Refresh cached MTP LR from optimizer param_groups so the - # DoubleScale log and SyncHealth STALL threshold can use the - # realised LR (not a hardcoded default). + # Refresh cached MTP LR from optimizer param_groups using + # max_lr fingerprint (ParamKey override in megatron-core 0.16 + # does NOT propagate the ParamKey.name into the param_group + # dict, so the previous name-based match always missed the MTP + # group and left _last_logged_mtp_lr at its default 3e-6, making + # the DoubleScale log severely misleading). try: - for _pg in getattr(self.optimizer, "param_groups", []): - _nm = str(_pg.get("name", "")) - if "mtp" in _nm.lower(): - self._last_logged_mtp_lr = float(_pg.get("lr", 3e-6)) - break + _pgs = getattr(self.optimizer, "param_groups", []) or [] + if len(_pgs) > 1: + _base_mx = _pgs[0].get("max_lr", None) + for _pg in _pgs: + _mxlr = _pg.get("max_lr", None) + if ( + _mxlr is not None + and _base_mx is not None + and abs(float(_mxlr) - float(_base_mx)) > 1e-12 + ): + self._last_logged_mtp_lr = float( + _pg.get("lr", _pg.get("max_lr", 3e-6)) + ) + break + else: + # Single-group case or equal max_lr -> MTP shares + # the base lr. + self._last_logged_mtp_lr = float( + _pgs[0].get("lr", 3e-6) + ) + elif len(_pgs) == 1: + self._last_logged_mtp_lr = float( + _pgs[0].get("lr", 3e-6) + ) except Exception: pass # Match Megatron-native MTPLossAutoScaler: @@ -3802,6 +3851,24 @@ def _compute_logprobs_and_loss( except Exception: _mtp_lr_dbg = 3e-6 _eff_step_mag = _eff_per_mb * _mtp_lr_dbg + # [MTPSanity-v12] Detect explosive per-step update. bf16 + # dynamic range for |W|~0.4 places 1 ULP near 3e-3; any + # per-step update >= 1e-2 is already tens of ULPs and + # almost always means the draft head is diverging. Emit + # a prominent warning rather than letting accept_rate + # silently collapse. + try: + if abs(_eff_step_mag) >= 1e-2: + self.logger.warning( + "[MTPSanity-v12] per-step MTP update " + "magnitude %.3e >= 1e-2 (>= ~3x bf16 ULP " + "for |W|~0.4); draft head divergence is " + "likely. Reduce mtp_lr_scale or " + "mtp_loss_scaling.", + _eff_step_mag, + ) + except Exception: + pass self.logger.info( "[MTPFix-DoubleScale-v6] Inverse-(loss_scale*num_mb) " "applied: loss_scale=%.6f, num_mb=%d, inv=%.4f, " diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 983a665c53..6b09adf474 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -60,7 +60,7 @@ actor: lr_scheduler_type: constant gradient_clipping: 1.0 warmup_steps_proportion: 0.001 - mtp_lr_scale: 1.0e4 + mtp_lr_scale: 1.0 eps_clip: 0.4 temperature: ${gconfig.temperature} reward_scaling: 10.0 From 410fb90483c785e0b3a0bfacce37af5733c6a917 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 30 Apr 2026 01:59:59 +0800 Subject: [PATCH 087/140] feat(megatron_engine): fp32 weight update --- areal/engine/megatron_engine.py | 205 +++++++++++++++++++++++++++++++- 1 file changed, 203 insertions(+), 2 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 832eb4cae8..cefe44492a 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -188,6 +188,34 @@ def __init__(self, config: TrainEngineConfig): self._mtp_layers_verified: bool = False self._mtp_tensor_update_warned: bool = False if self.enable_mtp_training: + # [MTPVersionBanner-v16] + v17 tag: make it trivial to + # verify which patch revision is running in a given log. + try: + import os as _os_banner + _banner_tags = [ + "v6:DoubleScaleInv", + "v9:bf16StallDiag", + "v11:VersionAudit", + "v12:OptimDump+Sanity", + "v14:LRScaleGuard+WeightDeltaGuard", + "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", + "v17:MTPNativeAutoScaler+ConsumerBypass" + "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", + ] + _banner_flags = { + "AREAL_MTP_FP32_BROADCAST": + _os_banner.environ.get( + "AREAL_MTP_FP32_BROADCAST", "0"), + "AREAL_MTP_NATIVE_AUTOSCALER": + _os_banner.environ.get( + "AREAL_MTP_NATIVE_AUTOSCALER", "0"), + } + self.logger.info( + "[MTPVersionBanner] tags=%s flags=%s", + ",".join(_banner_tags), _banner_flags, + ) + except Exception: + pass self.logger.info( f"[MTPTrain] MTP online training ENABLED: " f"num_layers={self.mtp_num_layers}, " @@ -1456,6 +1484,92 @@ def _patched_postprocess( else: _mtp_loss_to_store = mtp_loss_scale * mtp_loss / num_tokens _engine_ref._mtp_loss_for_backward.append(_mtp_loss_to_store) + + # --- BEGIN --- + # Reproduce Megatron-native behaviour: + # hidden_states = MTPLossAutoScaler.apply( + # hidden_states, + # mtp_loss_scale * mtp_loss [/ num_tokens], + # ) + # where MTPLossAutoScaler.backward() returns + # (grad_output, ones_like(mtp_loss) * + # main_loss_backward_scale). Combined with + # set_loss_scale(1/num_microbatches) this + # injects a per-token * per-vocab gradient + # of magnitude ~ mtp_loss_scale straight into + # the autograd graph, bypassing the scalar + # FIFO + DoubleScale-v6 inverse path. + # + # Gated so the legacy behaviour remains + # bit-exact by default. Enable with + # AREAL_MTP_NATIVE_AUTOSCALER=1 + try: + import os as _os_v17 + _v17_on = ( + _os_v17.environ.get( + "AREAL_MTP_NATIVE_AUTOSCALER", + "0", + ) == "1" + ) + except Exception: + _v17_on = False + if _v17_on: + try: + from megatron.core.transformer.multi_token_prediction import ( + MTPLossAutoScaler as _MTPLossAutoScaler_v17, + ) + _num_mb_v17 = int(getattr( + _engine_ref, + "_current_num_microbatches", + 1, + ) or 1) + if _num_mb_v17 <= 0: + _num_mb_v17 = 1 + import torch as _torch_v17 + # schedules.py sets + # main_loss_backward_scale = + # loss_scale / num_microbatches; + # AReaL's consumer already folds + # loss_scale via the outer + # loss * loss_scale contract, + # so only 1/num_mb is needed here. + _MTPLossAutoScaler_v17.set_loss_scale( + _torch_v17.tensor( + 1.0 / float(_num_mb_v17) + ) + ) + hidden_states = ( + _MTPLossAutoScaler_v17.apply( + hidden_states, + _mtp_loss_to_store, + ) + ) + _engine_ref._v17_native_active = True + if _mtp_diag_mb_counter[0] == 0: + _logger.info( + "[MTPNativeAutoScaler-v17] " + "apply() injected: " + "num_mb=%d, " + "main_loss_backward_scale=%.6e, " + "hidden_states.shape=%s, " + "hidden_states.rg=%s", + _num_mb_v17, + 1.0 / float(_num_mb_v17), + list(hidden_states.shape), + hidden_states.requires_grad, + ) + except Exception as _e_v17: + _engine_ref._v17_native_active = False + _logger.warning( + "[MTPNativeAutoScaler-v17] " + "apply() failed, falling back " + "to legacy FIFO+DoubleScale " + "path: %s", + _e_v17, + ) + else: + _engine_ref._v17_native_active = False + # --- [MTPNativeAutoScaler-v17] END --- # [v5-F4] Cap FIFO to avoid unbounded growth on producer/consumer drift. _fifo_len = len(_engine_ref._mtp_loss_for_backward) if _fifo_len > 32: @@ -3053,10 +3167,58 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: fp8_direct_convert=self.fp8_direct_convert, ) ) + # Upcast MTP-draft + # tensors to fp32 before serialization/broadcast so + # sub-bf16-ULP weight deltas are not rounded away + # on the wire. Complements upstream's NCCL-path + # [MTPBroadcastDtype] upcast (which only covers + # the distributed-weight-update path, not the + # MTPSerialize/update_weights_from_tensor path). + # Gated on AREAL_MTP_FP32_BROADCAST=1. + try: + import os as _os_v16 + _v16_on = ( + _os_v16.environ.get( + "AREAL_MTP_FP32_BROADCAST", "0", + ) == "1" + ) + except Exception: + _v16_on = False + if _v16_on: + import torch as _torch_v16 + _upcasted = 0 + for _i in range(_prev_count, len(mtp_hf_tensors)): + _nm_v16, _tn_v16 = mtp_hf_tensors[_i] + if _tn_v16.dtype == _torch_v16.bfloat16: + mtp_hf_tensors[_i] = ( + _nm_v16, + _tn_v16.float().contiguous(), + ) + _upcasted += 1 + if _upcasted > 0: + self.logger.info( + "[MTPBf16UpcastBroadcast-v16] Upcast %d MTP " + "tensors bf16->fp32 at MTPSerialize path " + "(name=%s).", + _upcasted, name, + ) # Diagnostic: log each converted MTP tensor with value # statistics for post-mortem debugging of weight corruption. for _hf_name, _hf_tensor in mtp_hf_tensors[_prev_count:]: _abs = _hf_tensor.float().abs() + # [MTPWeightDeltaGuard-v14] Flag all-zero MTP + # tensors explicitly so draft-head stall is + # surfaced independently of MTPWeightDiag. + try: + if float(_abs.max().item()) == 0.0: + self.logger.warning( + "[MTPWeightDeltaGuard-v14] MTP " + "tensor %s (hf=%s) has abs_max==0; " + "draft head is stalled this step.", + name, _hf_name, + ) + except Exception: + pass self.logger.info( f"[MTPWeightDiag] convert_to_hf: " f"megatron={name} -> hf={_hf_name}, " @@ -3774,7 +3936,17 @@ def _compute_logprobs_and_loss( _mtp_ema_decay * _ema_val + (1 - _mtp_ema_decay) * _raw_val ) - loss = loss + _mtp_contribution + if not bool(getattr(self, "_v17_native_active", False)): + loss = loss + _mtp_contribution + else: + # [MTPNativeConsumerBypass-v17] Native MTPLossAutoScaler + # already injected the gradient via autograd; adding + # _mtp_contribution scalar here would double-count. + if self._mtp_loss_total_count == 0: + self.logger.info( + "[MTPNativeConsumerBypass-v17] Skipping scalar " + "loss+=_mtp_contribution; autograd path active." + ) _n = self._mtp_loss_total_count if _n <= 4 or _n % 100 == 0: self.logger.info( @@ -3792,7 +3964,11 @@ def _compute_logprobs_and_loss( loss_scale, ) - if _mtp_loss_for_this_mb is not None and abs(loss_scale) > 0: + if ( + _mtp_loss_for_this_mb is not None + and abs(loss_scale) > 0 + and not bool(getattr(self, "_v17_native_active", False)) + ): # Refresh cached MTP LR from optimizer param_groups using # max_lr fingerprint (ParamKey override in megatron-core 0.16 # does NOT propagate the ParamKey.name into the param_group @@ -3824,6 +4000,31 @@ def _compute_logprobs_and_loss( self._last_logged_mtp_lr = float( _pgs[0].get("lr", 3e-6) ) + # [MTPLRScaleGuard-v14] detect obviously-wrong MTP lr. + try: + _mtp_lr_g = float( + getattr(self, "_last_logged_mtp_lr", 0.0) + ) + _base_lr_g = None + if _pgs: + _base_lr_g = float(_pgs[0].get("lr", 0.0)) + if ( + _base_lr_g is not None + and _base_lr_g > 0 + and _mtp_lr_g > 0 + and _mtp_lr_g >= 10.0 * _base_lr_g + and (self._mtp_loss_total_count <= 4 + or self._mtp_loss_total_count % 100 == 0) + ): + self.logger.warning( + "[MTPLRScaleGuard-v14] MTP lr %.3e is " + ">=10x base lr %.3e; this is almost " + "certainly a mis-scaled mtp_lr_scale " + "and will destabilise the draft head.", + _mtp_lr_g, _base_lr_g, + ) + except Exception: + pass except Exception: pass # Match Megatron-native MTPLossAutoScaler: From 8c3f60d72d9277f1c016f4e84d77eb3bf01bd0fe Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 2 May 2026 19:51:22 +0800 Subject: [PATCH 088/140] feat(engine): add full stage log --- areal/engine/megatron_engine.py | 375 +++++++++++++++++++++++++++++++- areal/engine/sglang_remote.py | 24 +- 2 files changed, 393 insertions(+), 6 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index cefe44492a..4588d5dd58 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -201,6 +201,7 @@ def __init__(self, config: TrainEngineConfig): "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", + "v20:SpecDecDiag(D01-D14 full pipeline instrumentation)", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": @@ -214,6 +215,60 @@ def __init__(self, config: TrainEngineConfig): "[MTPVersionBanner] tags=%s flags=%s", ",".join(_banner_tags), _banner_flags, ) + try: + import torch as _t_d01 + _dtype_d01 = str(getattr(self, "dtype", "n/a")) + _opt_cfg = getattr(self, "optimizer_config", None) + _mc_cfg = getattr(self, "mcore_config", None) + self.logger.info( + "[SpecDecDiag-v20 D01] EngineInit: " + "mtp_num_layers=%s mtp_loss_scaling_factor=%s " + "mtp_detach_heads=%s enable_mtp_training=%s " + "dtype=%s torch_version=%s", + getattr(self, "mtp_num_layers", None), + getattr(self, "mtp_loss_scaling_factor", None), + getattr(self, "mtp_detach_heads", None), + getattr(self, "enable_mtp_training", None), + _dtype_d01, _t_d01.__version__, + ) + if _opt_cfg is not None: + self.logger.info( + "[SpecDecDiag-v20 D01] EngineInit optimizer_cfg: " + "type=%s lr=%s weight_decay=%s beta1=%s beta2=%s " + "eps=%s mtp_lr_scale=%s gradient_clipping=%s " + "lr_scheduler_type=%s", + getattr(_opt_cfg, "type", None), + getattr(_opt_cfg, "lr", None), + getattr(_opt_cfg, "weight_decay", None), + getattr(_opt_cfg, "beta1", None), + getattr(_opt_cfg, "beta2", None), + getattr(_opt_cfg, "eps", None), + getattr(_opt_cfg, "mtp_lr_scale", None), + getattr(_opt_cfg, "gradient_clipping", None), + getattr(_opt_cfg, "lr_scheduler_type", None), + ) + if _mc_cfg is not None: + self.logger.info( + "[SpecDecDiag-v20 D01] EngineInit mcore_cfg: " + "use_precision_aware_optimizer=%s " + "exp_avg_dtype=%s exp_avg_sq_dtype=%s " + "use_distributed_optimizer=%s " + "overlap_param_gather_with_optimizer_step=%s", + getattr(_mc_cfg, + "use_precision_aware_optimizer", None), + getattr(_mc_cfg, "exp_avg_dtype", None), + getattr(_mc_cfg, "exp_avg_sq_dtype", None), + getattr(_mc_cfg, + "use_distributed_optimizer", None), + getattr(_mc_cfg, + "overlap_param_gather_with_optimizer_step", + None), + ) + except Exception as _e_d01: + self.logger.warning( + "[SpecDecDiag-v20 D01] static dump failed: %s", + _e_d01, + ) except Exception: pass self.logger.info( @@ -1011,6 +1066,51 @@ def _collect_mtp_loss(self) -> dict[str, float]: "[MTPGradDiag] Per-MTP-param gradient norms:\n" + "\n".join(mtp_param_details) ) + try: + _d09_step = getattr(self, "_global_step", 0) + if _d09_step <= 5 or _d09_step % 20 == 0: + _d09_rows = [] + import torch as _t_d09 + for _m in self.model: + for _n, _p in _m.named_parameters(): + if ".mtp." not in _n: + continue + _g = (_p.main_grad + if hasattr(_p, "main_grad") + and _p.main_grad is not None + else _p.grad) + if _g is None: + continue + _gf = _g.data.float() + _pf = _p.data.float() + _d09_rows.append( + "%s: dtype=%s |W|_max=%.3e " + "|W|_mean=%.3e " + "|g|_max=%.3e |g|_mean=%.3e " + "g.sum=%.3e g.finite=%s" % ( + _n, str(_p.dtype), + _pf.abs().max().item(), + _pf.abs().mean().item(), + _gf.abs().max().item(), + _gf.abs().mean().item(), + _gf.sum().item(), + bool(_t_d09.isfinite(_gf) + .all().item()), + ) + ) + if _d09_rows: + self.logger.info( + "[SpecDecDiag-v20 D09] step=%d " + "per-MTP-param grad+weight " + "snapshot:\n%s", + _d09_step, + "\n".join(_d09_rows), + ) + except Exception as _e_d09: + self.logger.warning( + "[SpecDecDiag-v20 D09] failed: %s", + _e_d09, + ) # Additional diagnostic: check if any MTP param # has .grad (not main_grad) with nonzero value, # which would indicate gradient accumulation fusion @@ -1092,8 +1192,86 @@ def _collect_mtp_loss(self) -> dict[str, float]: return mtp_stats def optimizer_step(self): + # [SpecDecDiag-v20 D10] pre-step MTP weight snapshot. + _d10_pre = {} + try: + _d10_step = int(getattr(self, "_global_step", 0) or 0) + _d10_sample = (_d10_step <= 20) or (_d10_step % 10 == 0) + if (self.enable_mtp_training and _d10_sample + and getattr(self, "model", None) is not None): + for _mod in self.model: + for _n, _p in _mod.named_parameters(): + if ".mtp." not in _n: + continue + try: + _d10_pre[_n] = _p.detach().clone() + except Exception: + pass + except Exception as _e_d10: + self.logger.warning( + "[SpecDecDiag-v20 D10] snapshot failed: %s", _e_d10, + ) + with trace_scope("megatron_engine.step"): update_successful, grad_norm, _ = self.optimizer.step() + + # [SpecDecDiag-v20 D11] post-step |deltaW| per MTP tensor. + try: + import torch as _t_d11 + if _d10_pre: + _step_d11 = int(getattr(self, "_global_step", 0) or 0) + _rows = [] + _floor_est = 7.78e-3 + _n_total = 0 + _n_stalled = 0 + _max_delta_global = 0.0 + for _mod in self.model: + for _n, _p in _mod.named_parameters(): + if _n not in _d10_pre: + continue + _pre = _d10_pre[_n] + try: + _delta = (_p.detach() - _pre).float().abs() + _max = float(_delta.max().item()) + _mean = float(_delta.mean().item()) + _norm = float(_delta.norm().item()) + _w_abs_max = float(_p.detach().float() + .abs().max().item()) + _n_total += 1 + _stalled = _max == 0.0 + if _stalled: + _n_stalled += 1 + if _max > _max_delta_global: + _max_delta_global = _max + if len(_rows) < 8 or _stalled: + _rows.append( + "%s: |dW|_max=%.3e mean=%.3e " + "norm=%.3e |W|_max=%.3e %s" % ( + _n, _max, _mean, _norm, + _w_abs_max, + "STALLED" if _stalled else "", + ) + ) + except Exception: + pass + self.logger.info( + "[SpecDecDiag-v20 D11] PostOpt step=%d " + "total=%d stalled=%d max|dW|_global=%.3e " + "bf16_ulp_floor_est=%.3e", + _step_d11, _n_total, _n_stalled, + _max_delta_global, _floor_est, + ) + if _rows: + self.logger.info( + "[SpecDecDiag-v20 D11] per-tensor (step=%d):\n%s", + _step_d11, "\n".join(_rows), + ) + _d10_pre.clear() + except Exception as _e_d11: + self.logger.warning( + "[SpecDecDiag-v20 D11] compare failed: %s", _e_d11, + ) + current_lr = self.optimizer.param_groups[0]["lr"] # Log MTP lr if using separate param group @@ -1448,6 +1626,56 @@ def _patched_postprocess( mtp_labels, mtp_logits ) mtp_loss = loss_mask * mtp_loss + try: + _d05_step = getattr( + _engine_ref, "_global_step", 0) + _d05_mb = _mtp_diag_mb_counter[0] + _d05_gate = (_d05_mb == 0 and + (_d05_step <= 5 + or _d05_step % 50 == 0)) + if _d05_gate: + import torch as _t_d05 + _hs_f = hidden_states.detach().float() + _lm_f = loss_mask.detach().float() + _logger.info( + "[SpecDecDiag-v20 D05] " + "MTPLayer#%d step=%d " + "hidden_states: shape=%s " + "dtype=%s rg=%s " + "abs_mean=%.3e abs_max=%.3e " + "finite=%s", + mtp_layer_number, _d05_step, + list(hidden_states.shape), + str(hidden_states.dtype), + hidden_states.requires_grad, + _hs_f.abs().mean().item(), + _hs_f.abs().max().item(), + bool(_t_d05.isfinite(_hs_f) + .all().item()), + ) + _logger.info( + "[SpecDecDiag-v20 D05] " + "MTPLayer#%d step=%d " + "loss_mask: shape=%s " + "num_tokens=%s sum=%.1f " + "mtp_loss_raw: abs_mean=%.3e " + "abs_max=%.3e sum=%.6f", + mtp_layer_number, _d05_step, + list(loss_mask.shape), + num_tokens, + _lm_f.sum().item(), + mtp_loss.detach().float() + .abs().mean().item(), + mtp_loss.detach().float() + .abs().max().item(), + mtp_loss.detach().float() + .sum().item(), + ) + except Exception as _e_d05: + _logger.warning( + "[SpecDecDiag-v20 D05] failed: %s", + _e_d05, + ) # [v5-F1c] Gate MB#0 mtp_loss diag to first 3 steps + every 100. _gs_ml = getattr(_engine_ref, '_global_step', 0) if (_mtp_diag_mb_counter[0] == 0 @@ -1538,12 +1766,90 @@ def _patched_postprocess( 1.0 / float(_num_mb_v17) ) ) + try: + _d06_step = getattr( + _engine_ref, + "_global_step", 0) + if (_mtp_diag_mb_counter[0] == 0 + and (_d06_step <= 5 + or _d06_step % 50 == 0)): + _logger.info( + "[SpecDecDiag-v20 " + "D06] " + "step=%d mtp_layer=%d " + "mtp_loss_scale=%.6e " + "calculate_per_token_" + "loss=%s " + "num_tokens=%s " + "num_mb=%d " + "mtp_loss_to_store:" + " shape=%s rg=%s " + "sum=%.6e abs_max=%.3e", + _d06_step, + mtp_layer_number, + float(mtp_loss_scale), + self_model.config + .calculate_per_token_loss, + num_tokens, + _num_mb_v17, + list(_mtp_loss_to_store + .shape), + _mtp_loss_to_store + .requires_grad, + _mtp_loss_to_store + .detach().float() + .sum().item(), + _mtp_loss_to_store + .detach().float() + .abs().max().item(), + ) + except Exception as _e_d06: + _logger.warning( + "[SpecDecDiag-v20 D06] " + "failed: %s", _e_d06, + ) hidden_states = ( _MTPLossAutoScaler_v17.apply( hidden_states, _mtp_loss_to_store, ) ) + try: + _d07_bs = ( + _MTPLossAutoScaler_v17 + .main_loss_backward_scale + ) + _d07_bs_v = ( + float(_d07_bs.item()) + if hasattr(_d07_bs, "item") + else float(_d07_bs) + ) + if (_mtp_diag_mb_counter[0] == 0 + and (_d06_step <= 5 + or _d06_step % 50 + == 0)): + _logger.info( + "[SpecDecDiag-v20 " + "D07] step=%d " + "mtp_layer=%d " + "post-apply " + "main_loss_backward_" + "scale=%.6e " + "hs.grad_fn=%s", + _d06_step, + mtp_layer_number, + _d07_bs_v, + type(hidden_states + .grad_fn).__name__ + if hidden_states + .grad_fn + else "None", + ) + except Exception as _e_d07: + _logger.warning( + "[SpecDecDiag-v20 D07] " + "failed: %s", _e_d07, + ) _engine_ref._v17_native_active = True if _mtp_diag_mb_counter[0] == 0: _logger.info( @@ -1600,14 +1906,30 @@ def _patched_postprocess( and hidden_states.requires_grad and _should_log_bwd): def _mtp_backward_hook(grad, _lg=_logger, _gs=_gs_v5): - # Inner hook fires once per backward; log only on gated steps. + import torch as _t_d08 + _g_f = grad.float() _lg.info( "[MTPBwdDiag] AutoScaler backward FIRED (step=%d): " "grad.shape=%s, grad.norm=%.8f, " "grad.abs_max=%.8f", _gs, list(grad.shape), - grad.float().norm().item(), - grad.float().abs().max().item()) + _g_f.norm().item(), + _g_f.abs().max().item()) + _lg.info( + "[SpecDecDiag-v20 D08] " + "hs-bwd step=%d grad.abs_mean=%.3e " + "grad.mean=%.3e grad.std=%.3e " + "grad.nonzero_frac=%.3f " + "grad.finite=%s dtype=%s", + _gs, + _g_f.abs().mean().item(), + _g_f.mean().item(), + _g_f.std().item(), + (_g_f != 0).float().mean().item(), + bool(_t_d08.isfinite(_g_f) + .all().item()), + str(grad.dtype), + ) hidden_states.register_hook(_mtp_backward_hook) _logger.info( "[MTPFwdDiag] MB#0 Registered backward hook on " @@ -2052,6 +2374,36 @@ def process_output( self._global_step = 0 self._global_step += 1 + # [SpecDecDiag-v20 D04] per-step summary before fwd/bwd. + try: + _d04_nmb = int(self._current_num_microbatches) + _d04_ntok = int(getattr(self, "_current_n_tokens", 0) or 0) + _d04_pgs = getattr(self.optimizer, "param_groups", []) or [] + _d04_base_lr = float(_d04_pgs[0].get("lr", 0.0)) if _d04_pgs else 0.0 + _d04_base_max = float(_d04_pgs[0].get("max_lr", 0.0)) if _d04_pgs else 0.0 + _d04_mtp_lr = None + if self.enable_mtp_training and len(_d04_pgs) > 1: + for _pg in _d04_pgs: + if (_pg.get("max_lr", None) is not None + and abs(float(_pg.get("max_lr")) + - _d04_base_max) > 1e-12): + _d04_mtp_lr = float(_pg.get("lr", 0.0)) + break + self.logger.info( + "[SpecDecDiag-v20 D04] TrainStepEnter step=%d num_mb=%d " + "n_tokens=%d base_lr=%.3e base_max_lr=%.3e " + "mtp_lr=%s n_param_groups=%d loss_multiplier=%.3e", + self._global_step, _d04_nmb, _d04_ntok, + _d04_base_lr, _d04_base_max, + ("%.3e" % _d04_mtp_lr) if _d04_mtp_lr is not None else "base", + len(_d04_pgs), float(loss_multiplier), + ) + except Exception as _e_d04: + self.logger.warning( + "[SpecDecDiag-v20 D04] TrainStepEnter log failed: %s", + _e_d04, + ) + self.forward_backward_batch(mb_list, process_output, forward_only=False) DeviceRuntimeInfo.get_current().log("train_batch after forward_backward") @@ -3187,14 +3539,31 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if _v16_on: import torch as _torch_v16 _upcasted = 0 + _d12_sample = None for _i in range(_prev_count, len(mtp_hf_tensors)): _nm_v16, _tn_v16 = mtp_hf_tensors[_i] + if _d12_sample is None: + try: + _d12_sample = ( + _nm_v16, str(_tn_v16.dtype), + float(_tn_v16.float().abs() + .max().item()), + ) + except Exception: + _d12_sample = (_nm_v16, "n/a", 0.0) if _tn_v16.dtype == _torch_v16.bfloat16: mtp_hf_tensors[_i] = ( _nm_v16, _tn_v16.float().contiguous(), ) _upcasted += 1 + if _d12_sample is not None: + self.logger.info( + "[SpecDecDiag-v20 D12] pre-upcast-sample " + "hf=%s dtype=%s |W|_max=%.3e upcasted=%d", + _d12_sample[0], _d12_sample[1], + _d12_sample[2], _upcasted, + ) if _upcasted > 0: self.logger.info( "[MTPBf16UpcastBroadcast-v16] Upcast %d MTP " diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index b64e05efda..037b23cc6c 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -115,10 +115,28 @@ def parse_generation_response( if spec_accept_token_num is not None and spec_draft_token_num is not None: if spec_draft_token_num > 0: accept_rate = spec_accept_token_num / spec_draft_token_num - logger.debug( - f"[SpecDec] SGLang response: accept={spec_accept_token_num}, " - f"draft={spec_draft_token_num}, rate={accept_rate:.4f}" + import os as _os_d03 + _d03_to_debug = ( + _os_d03.environ.get( + "AREAL_SPECDEC_D03_DEBUG", "0") == "1" ) + if _d03_to_debug: + logger.debug( + f"[SpecDecDiag-v20 D03] SGLang response: " + f"accept={spec_accept_token_num}, " + f"draft={spec_draft_token_num}, " + f"rate={accept_rate:.4f}" + ) + else: + logger.info( + f"[SpecDecDiag-v20 D03] SGLang response: " + f"accept={spec_accept_token_num}, " + f"draft={spec_draft_token_num}, " + f"rate={accept_rate:.4f} " + f"prompt_tokens={meta_info.get('prompt_tokens', 'n/a')} " + f"completion_tokens=" + f"{meta_info.get('completion_tokens', 'n/a')}" + ) if stop_reason == "abort" and stop_message.startswith("Abort before prefill"): return HttpGenerationResult( output_tokens=[], From 56e5e08c8915bcdafcdd1dd78dae3d1bdfb9c4f7 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 00:07:27 +0800 Subject: [PATCH 089/140] fix(engine): fix mtp --- areal/engine/megatron_engine.py | 164 ++++++++++++++++++++++++++++++-- 1 file changed, 158 insertions(+), 6 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 4588d5dd58..9b8da551ab 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -202,14 +202,19 @@ def __init__(self, config: TrainEngineConfig): "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", "v20:SpecDecDiag(D01-D14 full pipeline instrumentation)", + "v21:MTPFp32MasterRead+DefaultsOn" + "(AREAL_MTP_FP32_MASTER_READ,AREAL_MTP_FP32_BROADCAST=1)", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": _os_banner.environ.get( - "AREAL_MTP_FP32_BROADCAST", "0"), + "AREAL_MTP_FP32_BROADCAST", "1"), "AREAL_MTP_NATIVE_AUTOSCALER": _os_banner.environ.get( "AREAL_MTP_NATIVE_AUTOSCALER", "0"), + "AREAL_MTP_FP32_MASTER_READ": + _os_banner.environ.get( + "AREAL_MTP_FP32_MASTER_READ", "1"), } self.logger.info( "[MTPVersionBanner] tags=%s flags=%s", @@ -3506,7 +3511,125 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: mtp_param_count += 1 mtp_param_bytes += param.numel() * param.element_size() if _collect_mtp_for_draft: - _mtp_param, _ = self._collect_param(name, param) + # [MTPFp32MasterRead-v21] Root-cause fix for accept-rate + # decay: DistributedOptimizer.step() casts fp32 master + # params (main_param) to bf16 model params via + # _copy_main_params_to_model_params. + import os as _os_v21 + import torch as _torch_v21 + _master_read_on = ( + _os_v21.environ.get( + "AREAL_MTP_FP32_MASTER_READ", "1", + ) == "1" + ) + _fp32_full = None + if _master_read_on: + try: + _mp_shard = getattr(param, "main_param", None) + if ( + _mp_shard is not None + and isinstance(_mp_shard, _torch_v21.Tensor) + and _mp_shard.dtype == _torch_v21.float32 + ): + try: + _dp_group = mpu.get_data_parallel_group( + with_context_parallel=True, + ) + except TypeError: + _dp_group = ( + mpu.get_data_parallel_group() + ) + _dp_ws = _torch_v21.distributed.get_world_size( + group=_dp_group, + ) + if _dp_ws > 1: + _flat = _torch_v21.empty( + _mp_shard.numel() * _dp_ws, + dtype=_torch_v21.float32, + device=_mp_shard.device, + ) + _torch_v21.distributed.all_gather_into_tensor( + _flat, + _mp_shard.contiguous(), + group=_dp_group, + ) + else: + _flat = _mp_shard.contiguous() + # main_param shard is a flat view into the + # DP-sharded param buffer. DP all-gather + # produces the full flattened param; then + # slice+reshape to the model param shape. + _need = int(param.numel()) + if _flat.numel() >= _need: + _fp32_full = ( + _flat[:_need] + .view(param.shape) + .contiguous() + ) + # preserve TP sharding attributes so + # all_gather_param() treats this tensor + # the same as the bf16 model param. + if hasattr(param, "tensor_model_parallel"): + _fp32_full.tensor_model_parallel = ( + param.tensor_model_parallel + ) + if hasattr(param, "partition_dim"): + _fp32_full.partition_dim = ( + param.partition_dim + ) + if hasattr(param, "partition_stride"): + _fp32_full.partition_stride = ( + param.partition_stride + ) + self.logger.info( + "[MTPFp32MasterRead-v21 D15a] " + "name=%s dp_ws=%d fp32_full.shape=%s " + "fp32_abs_mean=%.6e fp32_abs_max=%.6e " + "(source=main_param)", + name, _dp_ws, + tuple(_fp32_full.shape), + float(_fp32_full.abs().mean().item()), + float(_fp32_full.abs().max().item()), + ) + else: + self.logger.warning( + "[MTPFp32MasterRead-v21] size " + "mismatch name=%s flat=%d need=%d; " + "falling back to bf16.", + name, int(_flat.numel()), _need, + ) + else: + if not getattr( + self, "_mtp_master_read_missing_warned", + False, + ): + self.logger.warning( + "[MTPFp32MasterRead-v21] " + "param.main_param unavailable for " + "name=%s (shard=%s); falling back " + "to bf16 model param. This will " + "re-expose the bf16-ULP flooring " + "root cause; ensure " + "DistributedOptimizer is used.", + name, + type(_mp_shard).__name__, + ) + self._mtp_master_read_missing_warned = ( + True + ) + except Exception as _e_v21: + self.logger.warning( + "[MTPFp32MasterRead-v21] error for " + "name=%s: %s; falling back to bf16.", + name, _e_v21, + ) + _fp32_full = None + if _fp32_full is not None: + _mtp_param, _ = self._collect_param( + name, _fp32_full, + ) + else: + _mtp_param, _ = self._collect_param(name, param) _mtp_model_name = self.hf_config.model_type _prev_count = len(mtp_hf_tensors) mtp_hf_tensors.extend( @@ -3519,23 +3642,23 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: fp8_direct_convert=self.fp8_direct_convert, ) ) - # Upcast MTP-draft + # [MTPBf16UpcastBroadcast-v16] Upcast MTP-draft # tensors to fp32 before serialization/broadcast so # sub-bf16-ULP weight deltas are not rounded away # on the wire. Complements upstream's NCCL-path # [MTPBroadcastDtype] upcast (which only covers # the distributed-weight-update path, not the # MTPSerialize/update_weights_from_tensor path). - # Gated on AREAL_MTP_FP32_BROADCAST=1. + # v21: default flipped 0->1. try: import os as _os_v16 _v16_on = ( _os_v16.environ.get( - "AREAL_MTP_FP32_BROADCAST", "0", + "AREAL_MTP_FP32_BROADCAST", "1", ) == "1" ) except Exception: - _v16_on = False + _v16_on = True if _v16_on: import torch as _torch_v16 _upcasted = 0 @@ -3571,6 +3694,35 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: "(name=%s).", _upcasted, name, ) + # [MTPWeightDeltaD15] Inter-version abs_mean delta + # tracker: proves the fp32-master-read fix works by + # showing non-zero abs_mean delta each version. Before + # the fix, abs_mean was frozen across 62 versions. A + # delta magnitude > 1e-7 confirms SGLang sees an + # actually-updated draft head. + if not hasattr(self, "_mtp_d15_prev_abs_mean"): + self._mtp_d15_prev_abs_mean = {} + for _hf_nm_d15, _hf_tn_d15 in mtp_hf_tensors[_prev_count:]: + _am_d15 = float( + _hf_tn_d15.float().abs().mean().item(), + ) + _prev_am = self._mtp_d15_prev_abs_mean.get(_hf_nm_d15) + if _prev_am is None: + _dlt = None + else: + _dlt = _am_d15 - _prev_am + self._mtp_d15_prev_abs_mean[_hf_nm_d15] = _am_d15 + self.logger.info( + "[MTPWeightDeltaD15] hf=%s abs_mean=%.9e " + "delta=%s frozen=%s (dtype=%s src=%s)", + _hf_nm_d15, _am_d15, + ("%+0.3e" % _dlt) if _dlt is not None else "n/a", + ("True" if _dlt is not None + and abs(_dlt) < 1e-9 else "False"), + str(_hf_tn_d15.dtype), + "fp32master" if _fp32_full is not None + else "bf16model", + ) # Diagnostic: log each converted MTP tensor with value # statistics for post-mortem debugging of weight corruption. for _hf_name, _hf_tensor in mtp_hf_tensors[_prev_count:]: From 918fb3f2dd827b4386fc727d2851085f410c8696 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 01:34:55 +0800 Subject: [PATCH 090/140] fix(megatron_engine): mtp nccl error --- areal/engine/megatron_engine.py | 505 ++++++++++++++++++++++---------- 1 file changed, 343 insertions(+), 162 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 9b8da551ab..0187752c8f 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -204,6 +204,7 @@ def __init__(self, config: TrainEngineConfig): "v20:SpecDecDiag(D01-D14 full pipeline instrumentation)", "v21:MTPFp32MasterRead+DefaultsOn" "(AREAL_MTP_FP32_MASTER_READ,AREAL_MTP_FP32_BROADCAST=1)", + "v22:CollectiveDeadlockFix+PreScan+EarlyFlush", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": @@ -3503,6 +3504,147 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: and self.is_pipeline_parallel_head() ) + # [MTPPreScan-v22] Early diagnostic pre-scan for MTP params. + # Runs on ALL ranks, before the 250+ param main loop. If the + # main loop later hangs on an MTP collective (as happened in + # spec_v1.log.1), this pre-scan still captures each MTP param's + # main_param availability / fp32 stats so the root-cause + # signal survives in the log. + try: + import os as _os_v22 + import sys as _sys_v22 + import torch as _torch_v22 + _v22_is_pp_head = self.is_pipeline_parallel_head() + _v22_supports_tu = getattr( + self, "_engine_supports_tensor_update", False, + ) + _v22_enable_mtp = bool( + getattr(self, "enable_mtp_training", False) + ) + _v22_collect = ( + _v22_enable_mtp + and _v22_supports_tu + and _v22_is_pp_head + ) + _v22_master_on = ( + _os_v22.environ.get( + "AREAL_MTP_FP32_MASTER_READ", "1", + ) == "1" + ) + _v22_bcast_on = ( + _os_v22.environ.get( + "AREAL_MTP_FP32_BROADCAST", "1", + ) == "1" + ) + self.logger.info( + "[MTPPreScan-v22] ENTRY rank=%d version=%s " + "is_pp_head=%s supports_tu=%s enable_mtp=%s " + "collect=%s master_on=%s fp32_bcast_on=%s", + dist.get_rank(), + str(getattr(meta, "version", "?")), + str(_v22_is_pp_head), str(_v22_supports_tu), + str(_v22_enable_mtp), str(_v22_collect), + str(_v22_master_on), str(_v22_bcast_on), + ) + try: + for _h in list(self.logger.handlers): + try: + _h.flush() + except Exception: + pass + _sys_v22.stdout.flush() + except Exception: + pass + _v22_mtp_seen = 0 + _v22_ok = 0 + _v22_missing = 0 + for _v22_nm, _v22_p in get_named_parameters( + self.model, num_moe_experts, + ): + if ".experts." in _v22_nm: + continue + if ".mtp." not in _v22_nm: + continue + _v22_mtp_seen += 1 + _v22_mp = getattr(_v22_p, "main_param", None) + _v22_kind = type(_v22_mp).__name__ + _v22_dtype = ( + str(_v22_mp.dtype) + if isinstance(_v22_mp, _torch_v22.Tensor) + else "n/a" + ) + _v22_shard_numel = ( + int(_v22_mp.numel()) + if isinstance(_v22_mp, _torch_v22.Tensor) + else -1 + ) + _v22_fp32_am = -1.0 + _v22_fp32_amax = -1.0 + try: + if ( + isinstance(_v22_mp, _torch_v22.Tensor) + and _v22_mp.dtype == _torch_v22.float32 + ): + _v22_ok += 1 + _v22_absf = _v22_mp.detach().abs() + _v22_fp32_am = float(_v22_absf.mean().item()) + _v22_fp32_amax = float(_v22_absf.max().item()) + else: + _v22_missing += 1 + except Exception: + _v22_missing += 1 + _v22_bf16_am = -1.0 + _v22_bf16_amax = -1.0 + try: + _v22_absb = _v22_p.detach().float().abs() + _v22_bf16_am = float(_v22_absb.mean().item()) + _v22_bf16_amax = float(_v22_absb.max().item()) + except Exception: + pass + self.logger.info( + "[MTPPreScan-v22] rank=%d name=%s " + "master_kind=%s master_dtype=%s " + "shard_numel=%d full_numel=%d " + "fp32_abs_mean=%.6e fp32_abs_max=%.6e " + "bf16_abs_mean=%.6e bf16_abs_max=%.6e " + "shape=%s", + dist.get_rank(), + _v22_nm, _v22_kind, _v22_dtype, + _v22_shard_numel, int(_v22_p.numel()), + _v22_fp32_am, _v22_fp32_amax, + _v22_bf16_am, _v22_bf16_amax, + tuple(_v22_p.shape), + ) + try: + for _h in list(self.logger.handlers): + try: + _h.flush() + except Exception: + pass + _sys_v22.stdout.flush() + except Exception: + pass + self.logger.info( + "[MTPPreScan-v22] SUMMARY rank=%d version=%s " + "mtp_params=%d master_ok=%d master_missing=%d", + dist.get_rank(), + str(getattr(meta, "version", "?")), + _v22_mtp_seen, _v22_ok, _v22_missing, + ) + try: + for _h in list(self.logger.handlers): + try: + _h.flush() + except Exception: + pass + _sys_v22.stdout.flush() + except Exception: + pass + except Exception as _e_v22: + self.logger.warning( + "[MTPPreScan-v22] aborted: %s", _e_v22, + ) + _param_idx = 0 for name, param in get_named_parameters(self.model, num_moe_experts): if ".experts." in name: @@ -3510,120 +3652,168 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if ".mtp." in name: mtp_param_count += 1 mtp_param_bytes += param.numel() * param.element_size() - if _collect_mtp_for_draft: - # [MTPFp32MasterRead-v21] Root-cause fix for accept-rate - # decay: DistributedOptimizer.step() casts fp32 master - # params (main_param) to bf16 model params via - # _copy_main_params_to_model_params. - import os as _os_v21 - import torch as _torch_v21 - _master_read_on = ( - _os_v21.environ.get( - "AREAL_MTP_FP32_MASTER_READ", "1", - ) == "1" - ) - _fp32_full = None - if _master_read_on: + # [MTPFp32MasterRead-v22 CollectiveSafe] + # Root-cause fix: DistributedOptimizer.step() casts + # fp32 master (main_param) to bf16 model param, and at + # |W|~0.4 the bf16 ULP is ~7.78e-3 while per-step fp32 + # deltas are 1e-6..1e-4, so SGLang saw a frozen draft. + # v21 tried to read the fp32 shard + DP-all_gather + # inside `if _collect_mtp_for_draft` (pp-head only) + # which DEADLOCKED because non-pp-head DP peers never + # joined the DP collective (they took the `else` branch + # calling TP-only all_gather_param via _collect_param). + # v22 moves the DP all_gather OUTSIDE the pp-head gate + # so every DP rank participates symmetrically. Only the + # pp-head rank actually uses the gathered fp32 buffer + # for serialize; non-pp peers still must drive + # _collect_param() afterwards for TP-side symmetry. + import os as _os_v22m + import sys as _sys_v22m + import torch as _torch_v22m + _mp_on = ( + _os_v22m.environ.get( + "AREAL_MTP_FP32_MASTER_READ", "1", + ) == "1" + ) + _fp32_full = None + if _mp_on: + try: try: - _mp_shard = getattr(param, "main_param", None) - if ( - _mp_shard is not None - and isinstance(_mp_shard, _torch_v21.Tensor) - and _mp_shard.dtype == _torch_v21.float32 - ): + _dp_group = mpu.get_data_parallel_group( + with_context_parallel=True, + ) + except TypeError: + _dp_group = mpu.get_data_parallel_group() + _dp_ws = _torch_v22m.distributed.get_world_size( + group=_dp_group, + ) + _mp_shard = getattr(param, "main_param", None) + _have_master = ( + isinstance(_mp_shard, _torch_v22m.Tensor) + and _mp_shard.dtype == _torch_v22m.float32 + ) + self.logger.info( + "[MTPFp32MasterRead-v22 ENTER] rank=%d " + "name=%s dp_ws=%d have_master=%s " + "shard_numel=%s need_numel=%d", + dist.get_rank(), name, _dp_ws, + str(_have_master), + (str(int(_mp_shard.numel())) + if isinstance( + _mp_shard, _torch_v22m.Tensor) + else "n/a"), + int(param.numel()), + ) + try: + for _h in list(self.logger.handlers): try: - _dp_group = mpu.get_data_parallel_group( - with_context_parallel=True, - ) - except TypeError: - _dp_group = ( - mpu.get_data_parallel_group() - ) - _dp_ws = _torch_v21.distributed.get_world_size( - group=_dp_group, + _h.flush() + except Exception: + pass + _sys_v22m.stdout.flush() + except Exception: + pass + if _have_master and _dp_ws > 1: + # COLLECTIVE: all DP peers must enter. + _flat = _torch_v22m.empty( + _mp_shard.numel() * _dp_ws, + dtype=_torch_v22m.float32, + device=_mp_shard.device, + ) + _torch_v22m.distributed.all_gather_into_tensor( + _flat, + _mp_shard.contiguous(), + group=_dp_group, + ) + _need = int(param.numel()) + if _flat.numel() >= _need: + _fp32_full = ( + _flat[:_need] + .view(param.shape) + .contiguous() ) - if _dp_ws > 1: - _flat = _torch_v21.empty( - _mp_shard.numel() * _dp_ws, - dtype=_torch_v21.float32, - device=_mp_shard.device, + # preserve TP-sharding attrs for + # downstream _collect_param's TP + # all_gather_param semantics. + if hasattr(param, "tensor_model_parallel"): + _fp32_full.tensor_model_parallel = ( + param.tensor_model_parallel ) - _torch_v21.distributed.all_gather_into_tensor( - _flat, - _mp_shard.contiguous(), - group=_dp_group, + if hasattr(param, "partition_dim"): + _fp32_full.partition_dim = ( + param.partition_dim ) - else: - _flat = _mp_shard.contiguous() - # main_param shard is a flat view into the - # DP-sharded param buffer. DP all-gather - # produces the full flattened param; then - # slice+reshape to the model param shape. - _need = int(param.numel()) - if _flat.numel() >= _need: - _fp32_full = ( - _flat[:_need] - .view(param.shape) - .contiguous() - ) - # preserve TP sharding attributes so - # all_gather_param() treats this tensor - # the same as the bf16 model param. - if hasattr(param, "tensor_model_parallel"): - _fp32_full.tensor_model_parallel = ( - param.tensor_model_parallel - ) - if hasattr(param, "partition_dim"): - _fp32_full.partition_dim = ( - param.partition_dim - ) - if hasattr(param, "partition_stride"): - _fp32_full.partition_stride = ( - param.partition_stride - ) - self.logger.info( - "[MTPFp32MasterRead-v21 D15a] " - "name=%s dp_ws=%d fp32_full.shape=%s " - "fp32_abs_mean=%.6e fp32_abs_max=%.6e " - "(source=main_param)", - name, _dp_ws, - tuple(_fp32_full.shape), - float(_fp32_full.abs().mean().item()), - float(_fp32_full.abs().max().item()), - ) - else: - self.logger.warning( - "[MTPFp32MasterRead-v21] size " - "mismatch name=%s flat=%d need=%d; " - "falling back to bf16.", - name, int(_flat.numel()), _need, + if hasattr(param, "partition_stride"): + _fp32_full.partition_stride = ( + param.partition_stride ) else: - if not getattr( - self, "_mtp_master_read_missing_warned", - False, - ): - self.logger.warning( - "[MTPFp32MasterRead-v21] " - "param.main_param unavailable for " - "name=%s (shard=%s); falling back " - "to bf16 model param. This will " - "re-expose the bf16-ULP flooring " - "root cause; ensure " - "DistributedOptimizer is used.", - name, - type(_mp_shard).__name__, + self.logger.warning( + "[MTPFp32MasterRead-v22] size " + "mismatch name=%s flat=%d need=%d;" + " falling back to bf16.", + name, int(_flat.numel()), _need, + ) + elif _have_master and _dp_ws == 1: + # DP=1: no collective, shard == full. + _fp32_full = ( + _mp_shard.view(param.shape).contiguous() + if _mp_shard.numel() == param.numel() + else None + ) + if _fp32_full is not None: + if hasattr(param, "tensor_model_parallel"): + _fp32_full.tensor_model_parallel = ( + param.tensor_model_parallel ) - self._mtp_master_read_missing_warned = ( - True + if hasattr(param, "partition_dim"): + _fp32_full.partition_dim = ( + param.partition_dim ) - except Exception as _e_v21: - self.logger.warning( - "[MTPFp32MasterRead-v21] error for " - "name=%s: %s; falling back to bf16.", - name, _e_v21, + if hasattr(param, "partition_stride"): + _fp32_full.partition_stride = ( + param.partition_stride + ) + else: + # No main_param -> rely on bf16 fallback. + # Every DP rank hits same branch -> safe. + if not getattr( + self, + "_mtp_master_read_missing_warned", + False, + ): + self.logger.warning( + "[MTPFp32MasterRead-v22] " + "param.main_param unavailable " + "(name=%s, kind=%s); falling back " + "to bf16 model param.", + name, + type(_mp_shard).__name__, + ) + self._mtp_master_read_missing_warned = True + if _fp32_full is not None: + self.logger.info( + "[MTPFp32MasterRead-v22 D15a] " + "rank=%d name=%s dp_ws=%d shape=%s " + "fp32_abs_mean=%.6e fp32_abs_max=%.6e " + "(source=main_param)", + dist.get_rank(), name, _dp_ws, + tuple(_fp32_full.shape), + float(_fp32_full.abs().mean().item()), + float(_fp32_full.abs().max().item()), ) - _fp32_full = None + except Exception as _e_v22m: + self.logger.warning( + "[MTPFp32MasterRead-v22] error " + "name=%s: %s; falling back to bf16.", + name, _e_v22m, + ) + _fp32_full = None + # Decide the source tensor for _collect_param. On + # pp-head rank we prefer the fp32 full buffer. On + # non-pp-head ranks we still must drive _collect_param + # so TP all_gather_param is called symmetrically. + if _collect_mtp_for_draft: if _fp32_full is not None: _mtp_param, _ = self._collect_param( name, _fp32_full, @@ -3642,100 +3832,88 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: fp8_direct_convert=self.fp8_direct_convert, ) ) - # [MTPBf16UpcastBroadcast-v16] Upcast MTP-draft - # tensors to fp32 before serialization/broadcast so - # sub-bf16-ULP weight deltas are not rounded away - # on the wire. Complements upstream's NCCL-path - # [MTPBroadcastDtype] upcast (which only covers - # the distributed-weight-update path, not the - # MTPSerialize/update_weights_from_tensor path). - # v21: default flipped 0->1. + # [MTPBf16UpcastBroadcast-v22] Upcast bf16->fp32 + # before serialize so sub-ULP deltas are not + # rounded on the wire (default flipped 0->1). try: - import os as _os_v16 _v16_on = ( - _os_v16.environ.get( + _os_v22m.environ.get( "AREAL_MTP_FP32_BROADCAST", "1", ) == "1" ) except Exception: _v16_on = True if _v16_on: - import torch as _torch_v16 _upcasted = 0 - _d12_sample = None for _i in range(_prev_count, len(mtp_hf_tensors)): _nm_v16, _tn_v16 = mtp_hf_tensors[_i] - if _d12_sample is None: - try: - _d12_sample = ( - _nm_v16, str(_tn_v16.dtype), - float(_tn_v16.float().abs() - .max().item()), - ) - except Exception: - _d12_sample = (_nm_v16, "n/a", 0.0) - if _tn_v16.dtype == _torch_v16.bfloat16: + if _tn_v16.dtype == _torch_v22m.bfloat16: mtp_hf_tensors[_i] = ( _nm_v16, _tn_v16.float().contiguous(), ) _upcasted += 1 - if _d12_sample is not None: - self.logger.info( - "[SpecDecDiag-v20 D12] pre-upcast-sample " - "hf=%s dtype=%s |W|_max=%.3e upcasted=%d", - _d12_sample[0], _d12_sample[1], - _d12_sample[2], _upcasted, - ) if _upcasted > 0: self.logger.info( - "[MTPBf16UpcastBroadcast-v16] Upcast %d MTP " - "tensors bf16->fp32 at MTPSerialize path " - "(name=%s).", + "[MTPBf16UpcastBroadcast-v22] Upcast %d " + "MTP tensors bf16->fp32 (name=%s).", _upcasted, name, ) - # [MTPWeightDeltaD15] Inter-version abs_mean delta - # tracker: proves the fp32-master-read fix works by - # showing non-zero abs_mean delta each version. Before - # the fix, abs_mean was frozen across 62 versions. A - # delta magnitude > 1e-7 confirms SGLang sees an - # actually-updated draft head. + # [MTPWeightDeltaD15] Inter-version abs_mean + # delta tracker. Proves the fp32-master-read + # fix works: non-zero delta / frozen=False + # across consecutive versions. if not hasattr(self, "_mtp_d15_prev_abs_mean"): self._mtp_d15_prev_abs_mean = {} - for _hf_nm_d15, _hf_tn_d15 in mtp_hf_tensors[_prev_count:]: + for _hf_nm_d15, _hf_tn_d15 in ( + mtp_hf_tensors[_prev_count:] + ): _am_d15 = float( _hf_tn_d15.float().abs().mean().item(), ) - _prev_am = self._mtp_d15_prev_abs_mean.get(_hf_nm_d15) - if _prev_am is None: - _dlt = None - else: - _dlt = _am_d15 - _prev_am - self._mtp_d15_prev_abs_mean[_hf_nm_d15] = _am_d15 + _prev_am = self._mtp_d15_prev_abs_mean.get( + _hf_nm_d15, + ) + _dlt = ( + None if _prev_am is None + else _am_d15 - _prev_am + ) + self._mtp_d15_prev_abs_mean[_hf_nm_d15] = ( + _am_d15 + ) self.logger.info( - "[MTPWeightDeltaD15] hf=%s abs_mean=%.9e " - "delta=%s frozen=%s (dtype=%s src=%s)", + "[MTPWeightDeltaD15] hf=%s " + "abs_mean=%.9e delta=%s frozen=%s " + "(dtype=%s src=%s)", _hf_nm_d15, _am_d15, - ("%+0.3e" % _dlt) if _dlt is not None else "n/a", + (("%+0.3e" % _dlt) + if _dlt is not None else "n/a"), ("True" if _dlt is not None and abs(_dlt) < 1e-9 else "False"), str(_hf_tn_d15.dtype), - "fp32master" if _fp32_full is not None - else "bf16model", + ("fp32master" if _fp32_full is not None + else "bf16model"), ) - # Diagnostic: log each converted MTP tensor with value - # statistics for post-mortem debugging of weight corruption. - for _hf_name, _hf_tensor in mtp_hf_tensors[_prev_count:]: + try: + for _h in list(self.logger.handlers): + try: + _h.flush() + except Exception: + pass + _sys_v22m.stdout.flush() + except Exception: + pass + # Diagnostic: per-tensor stats. + for _hf_name, _hf_tensor in ( + mtp_hf_tensors[_prev_count:] + ): _abs = _hf_tensor.float().abs() - # [MTPWeightDeltaGuard-v14] Flag all-zero MTP - # tensors explicitly so draft-head stall is - # surfaced independently of MTPWeightDiag. try: if float(_abs.max().item()) == 0.0: self.logger.warning( "[MTPWeightDeltaGuard-v14] MTP " - "tensor %s (hf=%s) has abs_max==0; " - "draft head is stalled this step.", + "tensor %s (hf=%s) has abs_max" + "==0; draft head is stalled.", name, _hf_name, ) except Exception: @@ -3751,6 +3929,9 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"norm={_hf_tensor.float().norm().item():.6e}" ) else: + # non-pp-head rank still drives TP all_gather for + # symmetry (unchanged behaviour). v22 ensures this + # rank already joined the DP collective above. self._collect_param(name, param) continue if self.config.use_lora and ( From 8913215fbb2c4b46aff219db45500975d70a9eee Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 02:09:26 +0800 Subject: [PATCH 091/140] fix(engine): fix again --- areal/engine/megatron_engine.py | 273 +++++++++++++++++--------------- 1 file changed, 149 insertions(+), 124 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 0187752c8f..ab7cdc2dba 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -205,6 +205,7 @@ def __init__(self, config: TrainEngineConfig): "v21:MTPFp32MasterRead+DefaultsOn" "(AREAL_MTP_FP32_MASTER_READ,AREAL_MTP_FP32_BROADCAST=1)", "v22:CollectiveDeadlockFix+PreScan+EarlyFlush", + "v23:NoDPCollective+TPConsensus+MainParamViewDirect", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": @@ -3505,11 +3506,9 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) # [MTPPreScan-v22] Early diagnostic pre-scan for MTP params. - # Runs on ALL ranks, before the 250+ param main loop. If the - # main loop later hangs on an MTP collective (as happened in - # spec_v1.log.1), this pre-scan still captures each MTP param's - # main_param availability / fp32 stats so the root-cause - # signal survives in the log. + # Runs on ALL ranks before the main param loop so that each + # MTP param's main_param availability / fp32 stats survive + # even if the later loop hangs. try: import os as _os_v22 import sys as _sys_v22 @@ -3652,55 +3651,71 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if ".mtp." in name: mtp_param_count += 1 mtp_param_bytes += param.numel() * param.element_size() - # [MTPFp32MasterRead-v22 CollectiveSafe] - # Root-cause fix: DistributedOptimizer.step() casts - # fp32 master (main_param) to bf16 model param, and at - # |W|~0.4 the bf16 ULP is ~7.78e-3 while per-step fp32 - # deltas are 1e-6..1e-4, so SGLang saw a frozen draft. - # v21 tried to read the fp32 shard + DP-all_gather - # inside `if _collect_mtp_for_draft` (pp-head only) - # which DEADLOCKED because non-pp-head DP peers never - # joined the DP collective (they took the `else` branch - # calling TP-only all_gather_param via _collect_param). - # v22 moves the DP all_gather OUTSIDE the pp-head gate - # so every DP rank participates symmetrically. Only the - # pp-head rank actually uses the gathered fp32 buffer - # for serialize; non-pp peers still must drive - # _collect_param() afterwards for TP-side symmetry. - import os as _os_v22m - import sys as _sys_v22m - import torch as _torch_v22m + # NoDPCollective path. + # Root cause of the iter10 hang: v22 issued a DP + # all_gather_into_tensor under `if _have_master` + # which is True only on the DP rank that *owns* + # the fp32 master (DistributedOptimizer assigns + # each param-bucket to ONE DP rank), so non-owning + # DP peers skipped the collective -> NCCL hang. + # PreScan-v22 proved shard_numel == full_numel + # on owning rank, i.e. main_param is ALREADY the + # full TP-shard. Therefore the DP all_gather is + # unnecessary and removed. + # + # TP-safety: DistributedOptimizer ownership is + # per-bucket by DP rank, so all TP peers of the + # owning DP rank share the same ownership status. + # A TP-group MIN all_reduce on the have_master + # bool provides a belt-and-braces sanity check. + import os as _os_v23m + import sys as _sys_v23m + import torch as _torch_v23m _mp_on = ( - _os_v22m.environ.get( + _os_v23m.environ.get( "AREAL_MTP_FP32_MASTER_READ", "1", ) == "1" ) _fp32_full = None + _src_tag = "bf16model" if _mp_on: try: + _mp_shard = getattr(param, "main_param", None) + _have_master_local = ( + isinstance(_mp_shard, _torch_v23m.Tensor) + and _mp_shard.dtype == _torch_v23m.float32 + and int(_mp_shard.numel()) + == int(param.numel()) + ) + # Fetch group handles (DP info only for log). try: _dp_group = mpu.get_data_parallel_group( with_context_parallel=True, ) except TypeError: _dp_group = mpu.get_data_parallel_group() - _dp_ws = _torch_v22m.distributed.get_world_size( + _dp_ws = _torch_v23m.distributed.get_world_size( group=_dp_group, ) - _mp_shard = getattr(param, "main_param", None) - _have_master = ( - isinstance(_mp_shard, _torch_v22m.Tensor) - and _mp_shard.dtype == _torch_v22m.float32 + try: + _tp_group = mpu.get_tensor_model_parallel_group() + except Exception: + _tp_group = None + _tp_ws = ( + _torch_v23m.distributed.get_world_size( + group=_tp_group) + if _tp_group is not None else 1 ) self.logger.info( - "[MTPFp32MasterRead-v22 ENTER] rank=%d " - "name=%s dp_ws=%d have_master=%s " - "shard_numel=%s need_numel=%d", - dist.get_rank(), name, _dp_ws, - str(_have_master), + "[MTPFp32MasterRead-v23 ENTER] rank=%d " + "name=%s dp_ws=%d tp_ws=%d " + "have_master_local=%s shard_numel=%s " + "need_numel=%d", + dist.get_rank(), name, _dp_ws, _tp_ws, + str(_have_master_local), (str(int(_mp_shard.numel())) if isinstance( - _mp_shard, _torch_v22m.Tensor) + _mp_shard, _torch_v23m.Tensor) else "n/a"), int(param.numel()), ) @@ -3710,109 +3725,123 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _h.flush() except Exception: pass - _sys_v22m.stdout.flush() + _sys_v23m.stdout.flush() except Exception: pass - if _have_master and _dp_ws > 1: - # COLLECTIVE: all DP peers must enter. - _flat = _torch_v22m.empty( - _mp_shard.numel() * _dp_ws, - dtype=_torch_v22m.float32, - device=_mp_shard.device, + # TP-group MIN all_reduce on have_master bool. + # All TP peers within the same TP group MUST + # call this collective together (they do: we + # are outside any `if _have_master` gate). + _have_master_tp = _have_master_local + if _tp_group is not None and _tp_ws > 1: + _dev = ( + _mp_shard.device + if isinstance( + _mp_shard, _torch_v23m.Tensor) + else param.device + ) + _hv = _torch_v23m.tensor( + [1 if _have_master_local else 0], + dtype=_torch_v23m.int32, + device=_dev, ) - _torch_v22m.distributed.all_gather_into_tensor( - _flat, - _mp_shard.contiguous(), - group=_dp_group, + _torch_v23m.distributed.all_reduce( + _hv, + op=( + _torch_v23m.distributed + .ReduceOp.MIN + ), + group=_tp_group, ) - _need = int(param.numel()) - if _flat.numel() >= _need: - _fp32_full = ( - _flat[:_need] - .view(param.shape) - .contiguous() + _have_master_tp = bool(int(_hv.item()) == 1) + if _have_master_tp != _have_master_local: + self.logger.warning( + "[MTPFp32MasterRead-v23 " + "TP_CONSENSUS] rank=%d name=%s " + "local=%s consensus=%s; " + "falling back to bf16 on this " + "TP peer to keep _collect_param " + "symmetric.", + dist.get_rank(), name, + str(_have_master_local), + str(_have_master_tp), ) - # preserve TP-sharding attrs for - # downstream _collect_param's TP - # all_gather_param semantics. - if hasattr(param, "tensor_model_parallel"): - _fp32_full.tensor_model_parallel = ( - param.tensor_model_parallel - ) - if hasattr(param, "partition_dim"): - _fp32_full.partition_dim = ( - param.partition_dim - ) - if hasattr(param, "partition_stride"): - _fp32_full.partition_stride = ( - param.partition_stride - ) else: - self.logger.warning( - "[MTPFp32MasterRead-v22] size " - "mismatch name=%s flat=%d need=%d;" - " falling back to bf16.", - name, int(_flat.numel()), _need, + self.logger.info( + "[MTPFp32MasterRead-v23 " + "TP_CONSENSUS] rank=%d name=%s " + "consensus=%s", + dist.get_rank(), name, + str(_have_master_tp), ) - elif _have_master and _dp_ws == 1: - # DP=1: no collective, shard == full. + if _have_master_tp and _have_master_local: + # Direct view: main_param is already the + # full TP-shard (shard_numel==full_numel). _fp32_full = ( - _mp_shard.view(param.shape).contiguous() - if _mp_shard.numel() == param.numel() - else None + _mp_shard.view(param.shape) + .contiguous() ) - if _fp32_full is not None: - if hasattr(param, "tensor_model_parallel"): - _fp32_full.tensor_model_parallel = ( - param.tensor_model_parallel - ) - if hasattr(param, "partition_dim"): - _fp32_full.partition_dim = ( - param.partition_dim - ) - if hasattr(param, "partition_stride"): - _fp32_full.partition_stride = ( - param.partition_stride - ) + if hasattr(param, "tensor_model_parallel"): + _fp32_full.tensor_model_parallel = ( + param.tensor_model_parallel + ) + if hasattr(param, "partition_dim"): + _fp32_full.partition_dim = ( + param.partition_dim + ) + if hasattr(param, "partition_stride"): + _fp32_full.partition_stride = ( + param.partition_stride + ) + _src_tag = "fp32master" else: - # No main_param -> rely on bf16 fallback. - # Every DP rank hits same branch -> safe. if not getattr( self, "_mtp_master_read_missing_warned", False, ): self.logger.warning( - "[MTPFp32MasterRead-v22] " - "param.main_param unavailable " - "(name=%s, kind=%s); falling back " - "to bf16 model param.", - name, + "[MTPFp32MasterRead-v23] " + "param.main_param unavailable on " + "this TP-group (rank=%d, " + "name=%s, kind=%s); using bf16 " + "model param.", + dist.get_rank(), name, type(_mp_shard).__name__, ) self._mtp_master_read_missing_warned = True if _fp32_full is not None: self.logger.info( - "[MTPFp32MasterRead-v22 D15a] " - "rank=%d name=%s dp_ws=%d shape=%s " - "fp32_abs_mean=%.6e fp32_abs_max=%.6e " - "(source=main_param)", - dist.get_rank(), name, _dp_ws, + "[MTPFp32MasterRead-v23 D15a] " + "rank=%d name=%s dp_ws=%d tp_ws=%d " + "shape=%s fp32_abs_mean=%.6e " + "fp32_abs_max=%.6e (source=%s)", + dist.get_rank(), name, _dp_ws, _tp_ws, tuple(_fp32_full.shape), float(_fp32_full.abs().mean().item()), float(_fp32_full.abs().max().item()), + _src_tag, ) - except Exception as _e_v22m: + try: + for _h in list(self.logger.handlers): + try: + _h.flush() + except Exception: + pass + _sys_v23m.stdout.flush() + except Exception: + pass + except Exception as _e_v23m: self.logger.warning( - "[MTPFp32MasterRead-v22] error " + "[MTPFp32MasterRead-v23] error " "name=%s: %s; falling back to bf16.", - name, _e_v22m, + name, _e_v23m, ) _fp32_full = None - # Decide the source tensor for _collect_param. On - # pp-head rank we prefer the fp32 full buffer. On - # non-pp-head ranks we still must drive _collect_param - # so TP all_gather_param is called symmetrically. + _src_tag = "bf16model" + # Feed _collect_param: pp-head uses the gathered + # tensors for serialize; non-pp peers drive it too + # so TP all_gather_param stays symmetric. if _collect_mtp_for_draft: if _fp32_full is not None: _mtp_param, _ = self._collect_param( @@ -3832,12 +3861,12 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: fp8_direct_convert=self.fp8_direct_convert, ) ) - # [MTPBf16UpcastBroadcast-v22] Upcast bf16->fp32 + # [MTPBf16UpcastBroadcast-v23] Upcast bf16->fp32 # before serialize so sub-ULP deltas are not - # rounded on the wire (default flipped 0->1). + # rounded on the wire (default 1). try: _v16_on = ( - _os_v22m.environ.get( + _os_v23m.environ.get( "AREAL_MTP_FP32_BROADCAST", "1", ) == "1" ) @@ -3847,7 +3876,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _upcasted = 0 for _i in range(_prev_count, len(mtp_hf_tensors)): _nm_v16, _tn_v16 = mtp_hf_tensors[_i] - if _tn_v16.dtype == _torch_v22m.bfloat16: + if _tn_v16.dtype == _torch_v23m.bfloat16: mtp_hf_tensors[_i] = ( _nm_v16, _tn_v16.float().contiguous(), @@ -3855,14 +3884,12 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _upcasted += 1 if _upcasted > 0: self.logger.info( - "[MTPBf16UpcastBroadcast-v22] Upcast %d " + "[MTPBf16UpcastBroadcast-v23] Upcast %d " "MTP tensors bf16->fp32 (name=%s).", _upcasted, name, ) - # [MTPWeightDeltaD15] Inter-version abs_mean - # delta tracker. Proves the fp32-master-read - # fix works: non-zero delta / frozen=False - # across consecutive versions. + # [MTPWeightDeltaD15] version-to-version + # abs_mean delta tracker. if not hasattr(self, "_mtp_d15_prev_abs_mean"): self._mtp_d15_prev_abs_mean = {} for _hf_nm_d15, _hf_tn_d15 in ( @@ -3891,8 +3918,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ("True" if _dlt is not None and abs(_dlt) < 1e-9 else "False"), str(_hf_tn_d15.dtype), - ("fp32master" if _fp32_full is not None - else "bf16model"), + _src_tag, ) try: for _h in list(self.logger.handlers): @@ -3900,10 +3926,10 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _h.flush() except Exception: pass - _sys_v22m.stdout.flush() + _sys_v23m.stdout.flush() except Exception: pass - # Diagnostic: per-tensor stats. + # Per-tensor stats. for _hf_name, _hf_tensor in ( mtp_hf_tensors[_prev_count:] ): @@ -3930,8 +3956,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) else: # non-pp-head rank still drives TP all_gather for - # symmetry (unchanged behaviour). v22 ensures this - # rank already joined the DP collective above. + # symmetry. self._collect_param(name, param) continue if self.config.use_lora and ( From ba01036ac2588b990146ab4b573f7ec045ce954e Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 02:32:46 +0800 Subject: [PATCH 092/140] fix(megatron_engine): fix --- areal/engine/megatron_engine.py | 93 ++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 43 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ab7cdc2dba..0bbc69e883 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -206,6 +206,7 @@ def __init__(self, config: TrainEngineConfig): "(AREAL_MTP_FP32_MASTER_READ,AREAL_MTP_FP32_BROADCAST=1)", "v22:CollectiveDeadlockFix+PreScan+EarlyFlush", "v23:NoDPCollective+TPConsensus+MainParamViewDirect", + "v24:TPDtypeSymmetric+NonPPHeadAlsoFp32", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": @@ -3651,7 +3652,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if ".mtp." in name: mtp_param_count += 1 mtp_param_bytes += param.numel() * param.element_size() - # NoDPCollective path. + # [MTPFp32MasterRead-v24] TPDtypeSymmetric path. # Root cause of the iter10 hang: v22 issued a DP # all_gather_into_tensor under `if _have_master` # which is True only on the DP rank that *owns* @@ -3668,11 +3669,21 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: # owning DP rank share the same ownership status. # A TP-group MIN all_reduce on the have_master # bool provides a belt-and-braces sanity check. - import os as _os_v23m - import sys as _sys_v23m - import torch as _torch_v23m + # + # v24 change: ALL ranks (pp-head or not) call + # _collect_param with the SAME source tensor + # (fp32_full or bf16 param) so the TP + # all_gather_param inside _collect_param sees a + # consistent dtype across TP peers. v23 had a + # subtle bug: pp-head passed fp32 while non-pp + # passed bf16 -> dtype mismatch inside + # all_gather_param -> silent data corruption or + # NCCL dtype error. + import os as _os_v24m + import sys as _sys_v24m + import torch as _torch_v24m _mp_on = ( - _os_v23m.environ.get( + _os_v24m.environ.get( "AREAL_MTP_FP32_MASTER_READ", "1", ) == "1" ) @@ -3682,8 +3693,8 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: try: _mp_shard = getattr(param, "main_param", None) _have_master_local = ( - isinstance(_mp_shard, _torch_v23m.Tensor) - and _mp_shard.dtype == _torch_v23m.float32 + isinstance(_mp_shard, _torch_v24m.Tensor) + and _mp_shard.dtype == _torch_v24m.float32 and int(_mp_shard.numel()) == int(param.numel()) ) @@ -3694,7 +3705,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) except TypeError: _dp_group = mpu.get_data_parallel_group() - _dp_ws = _torch_v23m.distributed.get_world_size( + _dp_ws = _torch_v24m.distributed.get_world_size( group=_dp_group, ) try: @@ -3702,12 +3713,12 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: except Exception: _tp_group = None _tp_ws = ( - _torch_v23m.distributed.get_world_size( + _torch_v24m.distributed.get_world_size( group=_tp_group) if _tp_group is not None else 1 ) self.logger.info( - "[MTPFp32MasterRead-v23 ENTER] rank=%d " + "[MTPFp32MasterRead-v24 ENTER] rank=%d " "name=%s dp_ws=%d tp_ws=%d " "have_master_local=%s shard_numel=%s " "need_numel=%d", @@ -3715,7 +3726,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: str(_have_master_local), (str(int(_mp_shard.numel())) if isinstance( - _mp_shard, _torch_v23m.Tensor) + _mp_shard, _torch_v24m.Tensor) else "n/a"), int(param.numel()), ) @@ -3725,7 +3736,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _h.flush() except Exception: pass - _sys_v23m.stdout.flush() + _sys_v24m.stdout.flush() except Exception: pass # TP-group MIN all_reduce on have_master bool. @@ -3737,18 +3748,18 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _dev = ( _mp_shard.device if isinstance( - _mp_shard, _torch_v23m.Tensor) + _mp_shard, _torch_v24m.Tensor) else param.device ) - _hv = _torch_v23m.tensor( + _hv = _torch_v24m.tensor( [1 if _have_master_local else 0], - dtype=_torch_v23m.int32, + dtype=_torch_v24m.int32, device=_dev, ) - _torch_v23m.distributed.all_reduce( + _torch_v24m.distributed.all_reduce( _hv, op=( - _torch_v23m.distributed + _torch_v24m.distributed .ReduceOp.MIN ), group=_tp_group, @@ -3756,7 +3767,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _have_master_tp = bool(int(_hv.item()) == 1) if _have_master_tp != _have_master_local: self.logger.warning( - "[MTPFp32MasterRead-v23 " + "[MTPFp32MasterRead-v24 " "TP_CONSENSUS] rank=%d name=%s " "local=%s consensus=%s; " "falling back to bf16 on this " @@ -3768,7 +3779,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) else: self.logger.info( - "[MTPFp32MasterRead-v23 " + "[MTPFp32MasterRead-v24 " "TP_CONSENSUS] rank=%d name=%s " "consensus=%s", dist.get_rank(), name, @@ -3801,7 +3812,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: False, ): self.logger.warning( - "[MTPFp32MasterRead-v23] " + "[MTPFp32MasterRead-v24] " "param.main_param unavailable on " "this TP-group (rank=%d, " "name=%s, kind=%s); using bf16 " @@ -3812,7 +3823,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: self._mtp_master_read_missing_warned = True if _fp32_full is not None: self.logger.info( - "[MTPFp32MasterRead-v23 D15a] " + "[MTPFp32MasterRead-v24 D15a] " "rank=%d name=%s dp_ws=%d tp_ws=%d " "shape=%s fp32_abs_mean=%.6e " "fp32_abs_max=%.6e (source=%s)", @@ -3828,27 +3839,27 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _h.flush() except Exception: pass - _sys_v23m.stdout.flush() + _sys_v24m.stdout.flush() except Exception: pass - except Exception as _e_v23m: + except Exception as _e_v24m: self.logger.warning( - "[MTPFp32MasterRead-v23] error " + "[MTPFp32MasterRead-v24] error " "name=%s: %s; falling back to bf16.", - name, _e_v23m, + name, _e_v24m, ) _fp32_full = None _src_tag = "bf16model" - # Feed _collect_param: pp-head uses the gathered - # tensors for serialize; non-pp peers drive it too - # so TP all_gather_param stays symmetric. + # v24: ALL ranks call _collect_param with the SAME + # source tensor (fp32_full or bf16 param) so the TP + # all_gather_param inside sees consistent dtype. + if _fp32_full is not None: + _mtp_param, _ = self._collect_param( + name, _fp32_full, + ) + else: + _mtp_param, _ = self._collect_param(name, param) if _collect_mtp_for_draft: - if _fp32_full is not None: - _mtp_param, _ = self._collect_param( - name, _fp32_full, - ) - else: - _mtp_param, _ = self._collect_param(name, param) _mtp_model_name = self.hf_config.model_type _prev_count = len(mtp_hf_tensors) mtp_hf_tensors.extend( @@ -3861,12 +3872,12 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: fp8_direct_convert=self.fp8_direct_convert, ) ) - # [MTPBf16UpcastBroadcast-v23] Upcast bf16->fp32 + # [MTPBf16UpcastBroadcast-v24] Upcast bf16->fp32 # before serialize so sub-ULP deltas are not # rounded on the wire (default 1). try: _v16_on = ( - _os_v23m.environ.get( + _os_v24m.environ.get( "AREAL_MTP_FP32_BROADCAST", "1", ) == "1" ) @@ -3876,7 +3887,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _upcasted = 0 for _i in range(_prev_count, len(mtp_hf_tensors)): _nm_v16, _tn_v16 = mtp_hf_tensors[_i] - if _tn_v16.dtype == _torch_v23m.bfloat16: + if _tn_v16.dtype == _torch_v24m.bfloat16: mtp_hf_tensors[_i] = ( _nm_v16, _tn_v16.float().contiguous(), @@ -3884,7 +3895,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _upcasted += 1 if _upcasted > 0: self.logger.info( - "[MTPBf16UpcastBroadcast-v23] Upcast %d " + "[MTPBf16UpcastBroadcast-v24] Upcast %d " "MTP tensors bf16->fp32 (name=%s).", _upcasted, name, ) @@ -3926,7 +3937,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _h.flush() except Exception: pass - _sys_v23m.stdout.flush() + _sys_v24m.stdout.flush() except Exception: pass # Per-tensor stats. @@ -3954,10 +3965,6 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"abs_max={_abs.max().item():.6e}, " f"norm={_hf_tensor.float().norm().item():.6e}" ) - else: - # non-pp-head rank still drives TP all_gather for - # symmetry. - self._collect_param(name, param) continue if self.config.use_lora and ( ".adapter." not in name or not getattr(param, "requires_grad", False) From 493346109a95ba5705a65a53a03c8d1fc6cf6f96 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 10:11:35 +0800 Subject: [PATCH 093/140] feat(engine): improve --- areal/engine/megatron_engine.py | 324 +++++++++++++++++++++++++------- areal/engine/sglang_remote.py | 22 +++ 2 files changed, 279 insertions(+), 67 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 0bbc69e883..2c208990ae 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -207,6 +207,7 @@ def __init__(self, config: TrainEngineConfig): "v22:CollectiveDeadlockFix+PreScan+EarlyFlush", "v23:NoDPCollective+TPConsensus+MainParamViewDirect", "v24:TPDtypeSymmetric+NonPPHeadAlsoFp32", + "v25:AcceptEMA+GradProbe+PostOptim+Bf16Drift+SendPreBcast", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": @@ -1221,7 +1222,53 @@ def optimizer_step(self): ) with trace_scope("megatron_engine.step"): + # [MTPGradProbe-v25] Diagnostic grad probe before optimizer.step(). + try: + _v25_probe_seen = set() + if getattr(self, "model", None) is not None: + for _v25_module in self.model: + for _v25_name, _v25_p in _v25_module.named_parameters(): + if ".mtp." not in _v25_name: + continue + if _v25_name in _v25_probe_seen: + continue + _v25_probe_seen.add(_v25_name) + _v25_mp = getattr(_v25_p, "main_param", None) + self.logger.info( + "[MTPGradProbe-v25] name=%s has_grad=%s grad_dtype=%s grad.abs_mean=%.3e grad.abs_max=%.3e grad.nonzero_frac=%.3f main_param_dtype=%s main_param_abs_mean=%.3e", + _v25_name, (_v25_p.grad is not None), str(_v25_p.grad.dtype if _v25_p.grad is not None else None), + (_v25_p.grad.abs().mean().item() if _v25_p.grad is not None else float('nan')), + (_v25_p.grad.abs().max().item() if _v25_p.grad is not None else float('nan')), + ((_v25_p.grad != 0).float().mean().item() if _v25_p.grad is not None else float('nan')), + str(_v25_mp.dtype if _v25_mp is not None else None), + (_v25_mp.abs().mean().item() if _v25_mp is not None else float('nan')) + ) + except Exception as _e: + self.logger.warning("[MTPGradProbe-v25] probe error: %s", _e) update_successful, grad_norm, _ = self.optimizer.step() + # [MTPPostOptim-v25] Diagnostic post-optimizer-step probe. + try: + _v25_post_seen = set() + if getattr(self, "model", None) is not None: + for _v25_module in self.model: + for _v25_name, _v25_p in _v25_module.named_parameters(): + if ".mtp." not in _v25_name: + continue + if _v25_name in _v25_post_seen: + continue + _v25_post_seen.add(_v25_name) + _v25_mp = getattr(_v25_p, "main_param", None) + self.logger.info( + "[MTPPostOptim-v25] name=%s main_param_abs_mean_post=%.6e bf16_model_abs_mean=%.6e " + "cast_diff_l1=%.3e cast_diff_linf=%.3e", + _v25_name, + (_v25_mp.abs().mean().item() if _v25_mp is not None else float('nan')), + _v25_p.data.abs().mean().item(), + ((_v25_mp.to(_v25_p.dtype) - _v25_p.data).abs().mean().item() if _v25_mp is not None else float('nan')), + ((_v25_mp.to(_v25_p.dtype) - _v25_p.data).abs().max().item() if _v25_mp is not None else float('nan')), + ) + except Exception as _e: + self.logger.warning("[MTPPostOptim-v25] probe error: %s", _e) # [SpecDecDiag-v20 D11] post-step |deltaW| per MTP tensor. try: @@ -2898,6 +2945,28 @@ def _update_bucket_weights_from_distributed( f"n_tensors={len(converted_named_tensors)}, n_specs={len(param_specs)}, " f"names={[n for n, _ in converted_named_tensors[:5]]}..." ) + # [MTPSendPreBcast-v25] Capture exact tensors to be broadcast. + try: + for _v25_name, _v25_t in converted_named_tensors: + if ("mtp_layers." in _v25_name or ".mtp." in _v25_name): + try: + _v25_first8 = [ + float(x) for x in _v25_t.flatten()[:8].tolist() + ] + except Exception: + _v25_first8 = [] + self.logger.info( + "[MTPSendPreBcast-v25] name=%s dtype=%s shape=%s " + "abs_mean=%.6e abs_max=%.6e first8=%s", + _v25_name, str(_v25_t.dtype), tuple(_v25_t.shape), + _v25_t.abs().mean().item(), + _v25_t.abs().max().item(), + _v25_first8, + ) + except Exception as _e_v25s: + self.logger.warning( + "[MTPSendPreBcast-v25] probe error: %s", _e_v25s, + ) _t_post0 = _diag_time.time() fut = self.rollout_engine.update_weights_from_distributed(meta, param_specs) self.logger.info( @@ -3652,33 +3721,19 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if ".mtp." in name: mtp_param_count += 1 mtp_param_bytes += param.numel() * param.element_size() - # [MTPFp32MasterRead-v24] TPDtypeSymmetric path. - # Root cause of the iter10 hang: v22 issued a DP - # all_gather_into_tensor under `if _have_master` - # which is True only on the DP rank that *owns* - # the fp32 master (DistributedOptimizer assigns - # each param-bucket to ONE DP rank), so non-owning - # DP peers skipped the collective -> NCCL hang. - # PreScan-v22 proved shard_numel == full_numel - # on owning rank, i.e. main_param is ALREADY the - # full TP-shard. Therefore the DP all_gather is - # unnecessary and removed. + # [MTPFp32MasterRead-v24] TP-dtype-symmetric path. # - # TP-safety: DistributedOptimizer ownership is - # per-bucket by DP rank, so all TP peers of the - # owning DP rank share the same ownership status. - # A TP-group MIN all_reduce on the have_master - # bool provides a belt-and-braces sanity check. + # Root cause of v23 hang at eh_proj.weight: + # rank 0 (pp-head,dp=0,tp=0): _collect_param(fp32) + # rank 1 (non-pp, dp=0,tp=1): _collect_param(bf16) + # _collect_param internally does TP all_gather_param; + # two TP peers feeding different dtypes -> NCCL hang. # - # v24 change: ALL ranks (pp-head or not) call - # _collect_param with the SAME source tensor - # (fp32_full or bf16 param) so the TP - # all_gather_param inside _collect_param sees a - # consistent dtype across TP peers. v23 had a - # subtle bug: pp-head passed fp32 while non-pp - # passed bf16 -> dtype mismatch inside - # all_gather_param -> silent data corruption or - # NCCL dtype error. + # v24 uniformly builds _fp32_full on the owning DP + # rank regardless of pp-head status, and ALWAYS passes + # the SAME dtype tensor to _collect_param on every TP + # peer of that TP group. Non-pp-head peer drops the + # returned collected tensor. import os as _os_v24m import sys as _sys_v24m import torch as _torch_v24m @@ -3698,7 +3753,6 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: and int(_mp_shard.numel()) == int(param.numel()) ) - # Fetch group handles (DP info only for log). try: _dp_group = mpu.get_data_parallel_group( with_context_parallel=True, @@ -3719,10 +3773,12 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) self.logger.info( "[MTPFp32MasterRead-v24 ENTER] rank=%d " - "name=%s dp_ws=%d tp_ws=%d " + "name=%s pp_head=%s dp_ws=%d tp_ws=%d " "have_master_local=%s shard_numel=%s " "need_numel=%d", - dist.get_rank(), name, _dp_ws, _tp_ws, + dist.get_rank(), name, + str(_collect_mtp_for_draft), + _dp_ws, _tp_ws, str(_have_master_local), (str(int(_mp_shard.numel())) if isinstance( @@ -3740,9 +3796,8 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: except Exception: pass # TP-group MIN all_reduce on have_master bool. - # All TP peers within the same TP group MUST - # call this collective together (they do: we - # are outside any `if _have_master` gate). + # Runs on ALL TP peers (outside any gate), so + # dtype-symmetric int32 tensors join. _have_master_tp = _have_master_local if _tp_group is not None and _tp_ws > 1: _dev = ( @@ -3764,30 +3819,22 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ), group=_tp_group, ) - _have_master_tp = bool(int(_hv.item()) == 1) - if _have_master_tp != _have_master_local: - self.logger.warning( - "[MTPFp32MasterRead-v24 " - "TP_CONSENSUS] rank=%d name=%s " - "local=%s consensus=%s; " - "falling back to bf16 on this " - "TP peer to keep _collect_param " - "symmetric.", - dist.get_rank(), name, - str(_have_master_local), - str(_have_master_tp), - ) - else: - self.logger.info( - "[MTPFp32MasterRead-v24 " - "TP_CONSENSUS] rank=%d name=%s " - "consensus=%s", - dist.get_rank(), name, - str(_have_master_tp), - ) + _have_master_tp = bool( + int(_hv.item()) == 1 + ) + self.logger.info( + "[MTPFp32MasterRead-v24 " + "TP_CONSENSUS] rank=%d name=%s " + "local=%s consensus=%s", + dist.get_rank(), name, + str(_have_master_local), + str(_have_master_tp), + ) + # Build fp32_full ONLY if BOTH local and TP + # consensus say yes. This way every TP peer + # ends up with _fp32_full==None XOR fp32, + # consistently across the TP group. if _have_master_tp and _have_master_local: - # Direct view: main_param is already the - # full TP-shard (shard_numel==full_numel). _fp32_full = ( _mp_shard.view(param.shape) .contiguous() @@ -3805,7 +3852,71 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: param.partition_stride ) _src_tag = "fp32master" + elif _have_master_tp and not _have_master_local: + # Shouldn't happen by DistributedOpt + # semantics, but guard: TP peer says yes, + # we say no -> we MUST still produce fp32 + # tensor to stay dtype-symmetric. Allocate + # a zero fp32 tensor of param.shape as a + # placeholder; the gathered output will be + # wrong on our slice but only the pp-head + # rank consumes the gather output, and + # the TP peer that DOES have master sends + # the correct slice. Wait -- _collect_param + # gathers every TP slice; if our slice is + # zero that taints the gathered tensor. + # In practice this branch is unreachable + # given ownership; downgrade to bf16-all. + self.logger.warning( + "[MTPFp32MasterRead-v24 " + "TP_CONSENSUS_ASYMM] rank=%d name=%s " + "consensus=True local=False; falling " + "back to bf16 on ENTIRE TP group to " + "avoid tainting gathered slice.", + dist.get_rank(), name, + ) + _fp32_full = None + _src_tag = "bf16model" + # Propagate decision via a second tiny + # all_reduce so the PEER sees it too. + _force_bf16 = _torch_v24m.tensor( + [1], + dtype=_torch_v24m.int32, + device=param.device, + ) + _torch_v24m.distributed.all_reduce( + _force_bf16, + op=( + _torch_v24m.distributed + .ReduceOp.MAX + ), + group=_tp_group, + ) else: + # Both TP peers agree: no master. + # Must also run the second all_reduce + # so symmetry with the asymm branch is + # maintained. + _force_bf16 = _torch_v24m.tensor( + [0], + dtype=_torch_v24m.int32, + device=param.device, + ) + _torch_v24m.distributed.all_reduce( + _force_bf16, + op=( + _torch_v24m.distributed + .ReduceOp.MAX + ), + group=_tp_group, + ) + if int(_force_bf16.item()) == 1: + self.logger.warning( + "[MTPFp32MasterRead-v24 " + "TP_CONSENSUS_ASYMM] rank=%d " + "name=%s forced bf16 by peer.", + dist.get_rank(), name, + ) if not getattr( self, "_mtp_master_read_missing_warned", @@ -3824,10 +3935,13 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if _fp32_full is not None: self.logger.info( "[MTPFp32MasterRead-v24 D15a] " - "rank=%d name=%s dp_ws=%d tp_ws=%d " - "shape=%s fp32_abs_mean=%.6e " + "rank=%d name=%s pp_head=%s " + "dp_ws=%d tp_ws=%d shape=%s " + "fp32_abs_mean=%.6e " "fp32_abs_max=%.6e (source=%s)", - dist.get_rank(), name, _dp_ws, _tp_ws, + dist.get_rank(), name, + str(_collect_mtp_for_draft), + _dp_ws, _tp_ws, tuple(_fp32_full.shape), float(_fp32_full.abs().mean().item()), float(_fp32_full.abs().max().item()), @@ -3850,16 +3964,56 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) _fp32_full = None _src_tag = "bf16model" - # v24: ALL ranks call _collect_param with the SAME - # source tensor (fp32_full or bf16 param) so the TP - # all_gather_param inside sees consistent dtype. - if _fp32_full is not None: - _mtp_param, _ = self._collect_param( - name, _fp32_full, - ) - else: - _mtp_param, _ = self._collect_param(name, param) - if _collect_mtp_for_draft: + # === v24 key change === + # Hand the SAME dtype tensor to _collect_param on + # every TP peer. pp-head consumes the gathered + # tensor; non-pp-head drops it. + _collect_src = ( + _fp32_full if _fp32_full is not None else param + ) + self.logger.info( + "[MTPFp32MasterRead-v24 COLLECT_SRC] rank=%d " + "name=%s pp_head=%s src_dtype=%s src_shape=%s " + "src_tag=%s", + dist.get_rank(), name, + str(_collect_mtp_for_draft), + str(_collect_src.dtype), + tuple(_collect_src.shape), + _src_tag, + ) + try: + for _h in list(self.logger.handlers): + try: + _h.flush() + except Exception: + pass + _sys_v24m.stdout.flush() + except Exception: + pass + _mtp_param, _ = self._collect_param( + name, _collect_src, + ) + self.logger.info( + "[MTPFp32MasterRead-v24 COLLECT_DONE] rank=%d " + "name=%s pp_head=%s returned_dtype=%s " + "returned_shape=%s", + dist.get_rank(), name, + str(_collect_mtp_for_draft), + (str(_mtp_param.dtype) + if _mtp_param is not None else "None"), + (tuple(_mtp_param.shape) + if _mtp_param is not None else "None"), + ) + try: + for _h in list(self.logger.handlers): + try: + _h.flush() + except Exception: + pass + _sys_v24m.stdout.flush() + except Exception: + pass + if _collect_mtp_for_draft and _mtp_param is not None: _mtp_model_name = self.hf_config.model_type _prev_count = len(mtp_hf_tensors) mtp_hf_tensors.extend( @@ -3919,6 +4073,42 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: self._mtp_d15_prev_abs_mean[_hf_nm_d15] = ( _am_d15 ) + # [MTPBf16Drift-v25] fp32 vs bf16-cast drift. + try: + import torch as _torch_v25d + _fp32_dtype = _hf_tn_d15.dtype + _fp32_ref = _hf_tn_d15.float() + _fp32_abs_mean = float( + _fp32_ref.abs().mean().item() + ) + try: + _bf16_cast = _hf_tn_d15.to( + _torch_v25d.bfloat16 + ).float() + _bf16_abs_mean = float( + _bf16_cast.abs().mean().item() + ) + _diff = (_fp32_ref - _bf16_cast).abs() + _diff_l1 = float(_diff.mean().item()) + _diff_linf = float(_diff.max().item()) + except Exception: + _bf16_abs_mean = float('nan') + _diff_l1 = float('nan') + _diff_linf = float('nan') + self.logger.info( + "[MTPBf16Drift-v25] hf=%s " + "fp32_abs_mean=%.6e bf16_abs_mean=%.6e " + "cast_diff_l1=%.3e cast_diff_linf=%.3e " + "fp32_dtype=%s", + _hf_nm_d15, _fp32_abs_mean, + _bf16_abs_mean, _diff_l1, + _diff_linf, str(_fp32_dtype), + ) + except Exception as _e_v25d: + self.logger.warning( + "[MTPBf16Drift-v25] error hf=%s: %s", + _hf_nm_d15, _e_v25d, + ) self.logger.info( "[MTPWeightDeltaD15] hf=%s " "abs_mean=%.9e delta=%s frozen=%s " diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 037b23cc6c..7abe90e7f9 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -137,6 +137,28 @@ def parse_generation_response( f"completion_tokens=" f"{meta_info.get('completion_tokens', 'n/a')}" ) + # [SpecDecDiag-v20 D03+] Accept-rate EMA on class instance. + try: + rate = accept_rate + if not hasattr(self, '_spec_dec_ema_short'): + self._spec_dec_ema_short = rate + self._spec_dec_ema_long = rate + self._spec_dec_rate_count = 1 + self._spec_dec_rate_sum = rate + else: + alpha_s = 2.0 / (64 + 1) + alpha_l = 2.0 / (256 + 1) + self._spec_dec_ema_short = alpha_s * rate + (1 - alpha_s) * self._spec_dec_ema_short + self._spec_dec_ema_long = alpha_l * rate + (1 - alpha_l) * self._spec_dec_ema_long + self._spec_dec_rate_count += 1 + self._spec_dec_rate_sum += rate + logger.info( + "[SpecDecDiag-v20 D03+] AcceptEMA: rate=%.4f ema64=%.4f ema256=%.4f global_n=%d global_avg=%.4f", + rate, self._spec_dec_ema_short, self._spec_dec_ema_long, + self._spec_dec_rate_count, self._spec_dec_rate_sum / self._spec_dec_rate_count, + ) + except Exception as _e: + pass if stop_reason == "abort" and stop_message.startswith("Abort before prefill"): return HttpGenerationResult( output_tokens=[], From 4b9a8e784cc981a1b2dca8c340a5b51cc329ab67 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 11:37:10 +0800 Subject: [PATCH 094/140] feat(megatron_engine): verify fp32 weight --- areal/engine/megatron_engine.py | 128 +++++++++++++++++++++++++++++++- 1 file changed, 127 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 2c208990ae..af7d6731f0 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -220,8 +220,9 @@ def __init__(self, config: TrainEngineConfig): _os_banner.environ.get( "AREAL_MTP_FP32_MASTER_READ", "1"), } + # [MTPVersionBanner-v26] Added iter14 instrumentation: MTPSerializeSendMTP-v26 / MTPGradProbe-v26 / SGLangReadBackMTP-v26. self.logger.info( - "[MTPVersionBanner] tags=%s flags=%s", + "[MTPVersionBanner-v26] tags=%s flags=%s", ",".join(_banner_tags), _banner_flags, ) try: @@ -1222,6 +1223,62 @@ def optimizer_step(self): ) with trace_scope("megatron_engine.step"): + # [MTPGradProbe-v26] Install post-accumulate-grad hook on MTP + # params (once) so grads are captured at the moment they land, + # BEFORE Megatron's DistributedOptimizer consumes and frees them. + try: + if not getattr(self, "_mtp_gradhook_v26_installed", False): + self._mtp_gradhook_v26_cache = {} + _v26_inst = 0 + for _v26_m in self.model: + for _v26_n, _v26_p in _v26_m.named_parameters(): + if ( + (".mtp_layers." in _v26_n + or ".mtp." in _v26_n + or ".enorm" in _v26_n + or ".hnorm" in _v26_n + or ".eh_proj" in _v26_n) + and _v26_p.requires_grad + ): + def _mk_hook(_nm): + def _hook(_p): + try: + if _p.grad is not None: + self._mtp_gradhook_v26_cache[_nm] = ( + float(_p.grad.abs().mean().item()), + float(_p.grad.abs().max().item()), + str(_p.grad.dtype), + ) + except Exception: + pass + return _hook + try: + _v26_p.register_post_accumulate_grad_hook( + _mk_hook(_v26_n) + ) + _v26_inst += 1 + except Exception: + pass + self._mtp_gradhook_v26_installed = True + self.logger.info( + "[MTPGradProbe-v26] installed post_accumulate_grad_hook " + "on %d MTP params", + _v26_inst, + ) + if getattr(self, "_mtp_gradhook_v26_cache", None): + for _v26_nm, (_am, _mx, _dt) in ( + self._mtp_gradhook_v26_cache.items() + ): + self.logger.info( + "[MTPGradProbe-v26] name=%s grad_abs_mean=%.3e " + "grad_abs_max=%.3e grad_dtype=%s", + _v26_nm, _am, _mx, _dt, + ) + self._mtp_gradhook_v26_cache = {} + except Exception as _e_v26g: + self.logger.warning( + "[MTPGradProbe-v26] outer error: %s", _e_v26g, + ) # [MTPGradProbe-v25] Diagnostic grad probe before optimizer.step(). try: _v25_probe_seen = set() @@ -3340,6 +3397,32 @@ def _serialize_mtp_tensors_for_update( f"tensor_dtypes={_tensor_dtypes}, " f"tensor_sizes_bytes={_tensor_sizes}" ) + # [MTPSerializeSendMTP-v26] Sample first 8 values of each MTP + # tensor so we can prove the actual bytes placed into the + # SGLang IPC payload. The earlier MTPSendPreBcast-v25 probe + # was installed on the /update_weights_from_distributed bucket + # path which MTP tensors bypass — explaining 0 events in log.7. + try: + for _v26_name, _v26_t in mtp_hf_tensors: + try: + _v26_first8 = [ + float(x) for x in _v26_t.flatten()[:8].tolist() + ] + except Exception: + _v26_first8 = [] + self.logger.info( + "[MTPSerializeSendMTP-v26] name=%s dtype=%s shape=%s " + "abs_mean=%.6e abs_max=%.6e first8=%s", + _v26_name, str(_v26_t.dtype), tuple(_v26_t.shape), + float(_v26_t.abs().mean().item()), + float(_v26_t.abs().max().item()), + _v26_first8, + ) + except Exception as _e_v26s: + self.logger.warning( + "[MTPSerializeSendMTP-v26] probe error: %s", _e_v26s, + ) + # ------------------------------------------------------------------- # GPU -> CPU copy on a *dedicated CUDA stream* that is insulated @@ -4486,6 +4569,49 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: serialized_payload=serialized_payload, flush_cache=True, ) + # Read back MTP LayerNorm weights + # from the SGLang draft model to verify what bytes the + # speculative decoder actually holds. This closes the + # loop: trainer -> serialize -> SGLang. + try: + if hasattr( + self.rollout_engine, "get_weights_by_parameter_name" + ): + _v26_probe_names = [ + "model.mtp_layers.0.token_layernorm.weight", + "model.mtp_layers.0.hidden_layernorm.weight", + "model.mtp_layers.0.input_layernorm.weight", + "model.mtp_layers.0.post_attention_layernorm.weight", + "model.mtp_layers.0.final_layernorm.weight", + ] + for _v26_pn in _v26_probe_names: + try: + _v26_rb = ( + self.rollout_engine + .get_weights_by_parameter_name( + _v26_pn, truncate_size=8, + ) + ) + self.logger.info( + "[SGLangReadBackMTP-v26] name=%s first8=%s", + _v26_pn, _v26_rb, + ) + except Exception as _e_v26rb1: + self.logger.info( + "[SGLangReadBackMTP-v26] name=%s " + "readback_unavailable err=%s", + _v26_pn, _e_v26rb1, + ) + else: + self.logger.info( + "[SGLangReadBackMTP-v26] rollout_engine lacks " + "get_weights_by_parameter_name; cannot read back." + ) + except Exception as _e_v26rb: + self.logger.warning( + "[SGLangReadBackMTP-v26] outer error: %s", + _e_v26rb, + ) _t_call1 = _time.time() self.logger.info( f"[DiagUW] Successfully updated EAGLE draft model " From 3ea975f10be4b80b6046df53fffc0b72c6272695 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 13:44:09 +0800 Subject: [PATCH 095/140] feat(engine): fix again --- areal/engine/megatron_engine.py | 350 ++++++++++++++++++++------------ 1 file changed, 225 insertions(+), 125 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index af7d6731f0..ec6d0b0d06 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -201,7 +201,6 @@ def __init__(self, config: TrainEngineConfig): "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", - "v20:SpecDecDiag(D01-D14 full pipeline instrumentation)", "v21:MTPFp32MasterRead+DefaultsOn" "(AREAL_MTP_FP32_MASTER_READ,AREAL_MTP_FP32_BROADCAST=1)", "v22:CollectiveDeadlockFix+PreScan+EarlyFlush", @@ -220,9 +219,9 @@ def __init__(self, config: TrainEngineConfig): _os_banner.environ.get( "AREAL_MTP_FP32_MASTER_READ", "1"), } - # [MTPVersionBanner-v26] Added iter14 instrumentation: MTPSerializeSendMTP-v26 / MTPGradProbe-v26 / SGLangReadBackMTP-v26. + # [MTPVersionBanner-v27] Added iter14 instrumentation: MTPSerializeSendMTP-v26 / MTPGradProbe-v26 / SGLangReadBackMTP-v26. self.logger.info( - "[MTPVersionBanner-v26] tags=%s flags=%s", + "[MTPVersionBanner-v27] tags=%s flags=%s", ",".join(_banner_tags), _banner_flags, ) try: @@ -1222,110 +1221,158 @@ def optimizer_step(self): "[SpecDecDiag-v20 D10] snapshot failed: %s", _e_d10, ) - with trace_scope("megatron_engine.step"): - # [MTPGradProbe-v26] Install post-accumulate-grad hook on MTP - # params (once) so grads are captured at the moment they land, - # BEFORE Megatron's DistributedOptimizer consumes and frees them. - try: - if not getattr(self, "_mtp_gradhook_v26_installed", False): - self._mtp_gradhook_v26_cache = {} - _v26_inst = 0 - for _v26_m in self.model: - for _v26_n, _v26_p in _v26_m.named_parameters(): - if ( - (".mtp_layers." in _v26_n - or ".mtp." in _v26_n - or ".enorm" in _v26_n - or ".hnorm" in _v26_n - or ".eh_proj" in _v26_n) - and _v26_p.requires_grad - ): - def _mk_hook(_nm): - def _hook(_p): - try: - if _p.grad is not None: - self._mtp_gradhook_v26_cache[_nm] = ( - float(_p.grad.abs().mean().item()), - float(_p.grad.abs().max().item()), - str(_p.grad.dtype), - ) - except Exception: - pass - return _hook - try: - _v26_p.register_post_accumulate_grad_hook( - _mk_hook(_v26_n) - ) - _v26_inst += 1 - except Exception: - pass - self._mtp_gradhook_v26_installed = True + # [MTPGradProbe-v26] Install post-accumulate-grad hook on MTP + # params (once) so grads are captured at the moment they land, + # BEFORE Megatron's DistributedOptimizer consumes and frees them. + try: + if not getattr(self, "_mtp_gradhook_v26_installed", False): + self._mtp_gradhook_v26_cache = {} + _v26_inst = 0 + for _v26_m in self.model: + for _v26_n, _v26_p in _v26_m.named_parameters(): + if ( + (".mtp_layers." in _v26_n + or ".mtp." in _v26_n + or ".enorm" in _v26_n + or ".hnorm" in _v26_n + or ".eh_proj" in _v26_n) + and _v26_p.requires_grad + ): + def _mk_hook(_nm): + def _hook(_p): + try: + if _p.grad is not None: + self._mtp_gradhook_v26_cache[_nm] = ( + float(_p.grad.abs().mean().item()), + float(_p.grad.abs().max().item()), + str(_p.grad.dtype), + ) + except Exception: + pass + return _hook + try: + _v26_p.register_post_accumulate_grad_hook( + _mk_hook(_v26_n) + ) + _v26_inst += 1 + except Exception: + pass + self._mtp_gradhook_v26_installed = True + self.logger.info( + "[MTPGradProbe-v26] installed post_accumulate_grad_hook " + "on %d MTP params", + _v26_inst, + ) + if getattr(self, "_mtp_gradhook_v26_cache", None): + for _v26_nm, (_am, _mx, _dt) in ( + self._mtp_gradhook_v26_cache.items() + ): self.logger.info( - "[MTPGradProbe-v26] installed post_accumulate_grad_hook " - "on %d MTP params", - _v26_inst, + "[MTPGradProbe-v26] name=%s grad_abs_mean=%.3e " + "grad_abs_max=%.3e grad_dtype=%s", + _v26_nm, _am, _mx, _dt, ) - if getattr(self, "_mtp_gradhook_v26_cache", None): - for _v26_nm, (_am, _mx, _dt) in ( - self._mtp_gradhook_v26_cache.items() + self._mtp_gradhook_v26_cache = {} + except Exception as _e_v26g: + self.logger.warning( + "[MTPGradProbe-v26] outer error: %s", _e_v26g, + ) + # [MTPMainGrad-v27] Log Megatron DistributedOptimizer's + # fp32 reduced gradient buffer (param.main_grad) just before + # optimizer.step(). This is the ACTUAL gradient the optimizer + # consumes (fp32, post-allreduce, post-inv-scale), not the raw + # bf16 .grad captured by the grad hook. Comparing the two + # distinguishes "grad vanishes in backward" vs "grad vanishes + # in allreduce/scaling pipeline". + try: + for _v27_m in self.model: + for _v27_n, _v27_p in _v27_m.named_parameters(): + if not ( + ".mtp_layers." in _v27_n + or ".mtp." in _v27_n + or ".enorm" in _v27_n + or ".hnorm" in _v27_n + or ".eh_proj" in _v27_n ): - self.logger.info( - "[MTPGradProbe-v26] name=%s grad_abs_mean=%.3e " - "grad_abs_max=%.3e grad_dtype=%s", - _v26_nm, _am, _mx, _dt, - ) - self._mtp_gradhook_v26_cache = {} - except Exception as _e_v26g: - self.logger.warning( - "[MTPGradProbe-v26] outer error: %s", _e_v26g, - ) - # [MTPGradProbe-v25] Diagnostic grad probe before optimizer.step(). - try: - _v25_probe_seen = set() - if getattr(self, "model", None) is not None: - for _v25_module in self.model: - for _v25_name, _v25_p in _v25_module.named_parameters(): - if ".mtp." not in _v25_name: - continue - if _v25_name in _v25_probe_seen: - continue - _v25_probe_seen.add(_v25_name) - _v25_mp = getattr(_v25_p, "main_param", None) + continue + try: + _v27_mg = getattr(_v27_p, "main_grad", None) + if _v27_mg is None: self.logger.info( - "[MTPGradProbe-v25] name=%s has_grad=%s grad_dtype=%s grad.abs_mean=%.3e grad.abs_max=%.3e grad.nonzero_frac=%.3f main_param_dtype=%s main_param_abs_mean=%.3e", - _v25_name, (_v25_p.grad is not None), str(_v25_p.grad.dtype if _v25_p.grad is not None else None), - (_v25_p.grad.abs().mean().item() if _v25_p.grad is not None else float('nan')), - (_v25_p.grad.abs().max().item() if _v25_p.grad is not None else float('nan')), - ((_v25_p.grad != 0).float().mean().item() if _v25_p.grad is not None else float('nan')), - str(_v25_mp.dtype if _v25_mp is not None else None), - (_v25_mp.abs().mean().item() if _v25_mp is not None else float('nan')) + "[MTPMainGrad-v27] name=%s main_grad=None " + ".grad_is_none=%s", + _v27_n, _v27_p.grad is None, ) - except Exception as _e: - self.logger.warning("[MTPGradProbe-v25] probe error: %s", _e) - update_successful, grad_norm, _ = self.optimizer.step() - # [MTPPostOptim-v25] Diagnostic post-optimizer-step probe. - try: - _v25_post_seen = set() - if getattr(self, "model", None) is not None: - for _v25_module in self.model: - for _v25_name, _v25_p in _v25_module.named_parameters(): - if ".mtp." not in _v25_name: - continue - if _v25_name in _v25_post_seen: - continue - _v25_post_seen.add(_v25_name) - _v25_mp = getattr(_v25_p, "main_param", None) + else: self.logger.info( - "[MTPPostOptim-v25] name=%s main_param_abs_mean_post=%.6e bf16_model_abs_mean=%.6e " - "cast_diff_l1=%.3e cast_diff_linf=%.3e", - _v25_name, - (_v25_mp.abs().mean().item() if _v25_mp is not None else float('nan')), - _v25_p.data.abs().mean().item(), - ((_v25_mp.to(_v25_p.dtype) - _v25_p.data).abs().mean().item() if _v25_mp is not None else float('nan')), - ((_v25_mp.to(_v25_p.dtype) - _v25_p.data).abs().max().item() if _v25_mp is not None else float('nan')), + "[MTPMainGrad-v27] name=%s dtype=%s " + "shape=%s abs_mean=%.3e abs_max=%.3e " + "nonzero_frac=%.3f", + _v27_n, str(_v27_mg.dtype), + tuple(_v27_mg.shape), + float(_v27_mg.abs().mean().item()), + float(_v27_mg.abs().max().item()), + float( + (_v27_mg.abs() > 0).float().mean().item() + ), ) - except Exception as _e: - self.logger.warning("[MTPPostOptim-v25] probe error: %s", _e) + except Exception as _e_v27mg1: + self.logger.info( + "[MTPMainGrad-v27] name=%s probe_err=%s", + _v27_n, _e_v27mg1, + ) + except Exception as _e_v27mg: + self.logger.warning( + "[MTPMainGrad-v27] outer error: %s", _e_v27mg, + ) + # [MTPGradProbe-v25] Diagnostic grad probe before optimizer.step(). + try: + _v25_probe_seen = set() + if getattr(self, "model", None) is not None: + for _v25_module in self.model: + for _v25_name, _v25_p in _v25_module.named_parameters(): + if ".mtp." not in _v25_name: + continue + if _v25_name in _v25_probe_seen: + continue + _v25_probe_seen.add(_v25_name) + _v25_mp = getattr(_v25_p, "main_param", None) + self.logger.info( + "[MTPGradProbe-v25] name=%s has_grad=%s grad_dtype=%s grad.abs_mean=%.3e grad.abs_max=%.3e grad.nonzero_frac=%.3f main_param_dtype=%s main_param_abs_mean=%.3e", + _v25_name, (_v25_p.grad is not None), str(_v25_p.grad.dtype if _v25_p.grad is not None else None), + (_v25_p.grad.abs().mean().item() if _v25_p.grad is not None else float('nan')), + (_v25_p.grad.abs().max().item() if _v25_p.grad is not None else float('nan')), + ((_v25_p.grad != 0).float().mean().item() if _v25_p.grad is not None else float('nan')), + str(_v25_mp.dtype if _v25_mp is not None else None), + (_v25_mp.abs().mean().item() if _v25_mp is not None else float('nan')) + ) + except Exception as _e: + self.logger.warning("[MTPGradProbe-v25] probe error: %s", _e) + with trace_scope("megatron_engine.step"): + update_successful, grad_norm, _ = self.optimizer.step() + # [MTPPostOptim-v25] Diagnostic post-optimizer-step probe. + try: + _v25_post_seen = set() + if getattr(self, "model", None) is not None: + for _v25_module in self.model: + for _v25_name, _v25_p in _v25_module.named_parameters(): + if ".mtp." not in _v25_name: + continue + if _v25_name in _v25_post_seen: + continue + _v25_post_seen.add(_v25_name) + _v25_mp = getattr(_v25_p, "main_param", None) + self.logger.info( + "[MTPPostOptim-v25] name=%s main_param_abs_mean_post=%.6e bf16_model_abs_mean=%.6e " + "cast_diff_l1=%.3e cast_diff_linf=%.3e", + _v25_name, + (_v25_mp.abs().mean().item() if _v25_mp is not None else float('nan')), + _v25_p.data.abs().mean().item(), + ((_v25_mp.to(_v25_p.dtype) - _v25_p.data).abs().mean().item() if _v25_mp is not None else float('nan')), + ((_v25_mp.to(_v25_p.dtype) - _v25_p.data).abs().max().item() if _v25_mp is not None else float('nan')), + ) + except Exception as _e: + self.logger.warning("[MTPPostOptim-v25] probe error: %s", _e) # [SpecDecDiag-v20 D11] post-step |deltaW| per MTP tensor. try: @@ -4156,6 +4203,38 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: self._mtp_d15_prev_abs_mean[_hf_nm_d15] = ( _am_d15 ) + # [MTPFp32Delta-v27] Track fp32 master abs_mean delta + # between consecutive MTP sync events. Combined with + # MTPBf16Drift-v25, makes it possible to compare + # "fp32 per-step update" vs "bf16 ULP" directly. + try: + if not hasattr(self, "_mtp_v27_fp32_prev"): + self._mtp_v27_fp32_prev = {} + _v27_fp32_am = float( + _hf_tn_d15.float().abs().mean().item() + ) + _v27_prev = self._mtp_v27_fp32_prev.get( + _hf_nm_d15 + ) + _v27_delta = ( + None if _v27_prev is None + else _v27_fp32_am - _v27_prev + ) + self._mtp_v27_fp32_prev[_hf_nm_d15] = ( + _v27_fp32_am + ) + self.logger.info( + "[MTPFp32Delta-v27] hf=%s " + "fp32_abs_mean=%.9e delta=%s", + _hf_nm_d15, _v27_fp32_am, + ("%+0.3e" % _v27_delta) + if _v27_delta is not None else "n/a", + ) + except Exception as _e_v27fd: + self.logger.warning( + "[MTPFp32Delta-v27] err hf=%s: %s", + _hf_nm_d15, _e_v27fd, + ) # [MTPBf16Drift-v25] fp32 vs bf16-cast drift. try: import torch as _torch_v25d @@ -4569,48 +4648,69 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: serialized_payload=serialized_payload, flush_cache=True, ) - # Read back MTP LayerNorm weights - # from the SGLang draft model to verify what bytes the - # speculative decoder actually holds. This closes the - # loop: trainer -> serialize -> SGLang. + # [SGLangReadBackMTPv2-v27] Read back MTP LayerNorm weights + # from the SGLang draft model over HTTP directly. + # iter14 used Python attribute access (missing on + # RemoteSGLangEngine). SGLang exposes + # /get_weights_by_parameter_name endpoint (introduced + # for verl/slime) which accepts JSON {name, truncate_size}. try: - if hasattr( - self.rollout_engine, "get_weights_by_parameter_name" - ): - _v26_probe_names = [ + import requests as _v27_requests + _v27_addrs = None + try: + _v27_inner = getattr( + self.rollout_engine, "_engine", None + ) + _v27_addrs = getattr(_v27_inner, "addresses", None) + except Exception: + _v27_addrs = None + if _v27_addrs: + _v27_probe = [ "model.mtp_layers.0.token_layernorm.weight", "model.mtp_layers.0.hidden_layernorm.weight", "model.mtp_layers.0.input_layernorm.weight", "model.mtp_layers.0.post_attention_layernorm.weight", "model.mtp_layers.0.final_layernorm.weight", ] - for _v26_pn in _v26_probe_names: + _v27_addr0 = _v27_addrs[0] + _v27_base = ( + _v27_addr0 if _v27_addr0.startswith("http") + else f"http://{_v27_addr0}" + ) + for _v27_nm in _v27_probe: try: - _v26_rb = ( - self.rollout_engine - .get_weights_by_parameter_name( - _v26_pn, truncate_size=8, - ) + _v27_resp = _v27_requests.post( + f"{_v27_base}/get_weights_by_parameter_name", + json={ + "name": _v27_nm, + "truncate_size": 8, + }, + timeout=15, ) + _v27_body = _v27_resp.text[:400] self.logger.info( - "[SGLangReadBackMTP-v26] name=%s first8=%s", - _v26_pn, _v26_rb, + "[SGLangReadBackMTPv2-v27] name=%s " + "status=%s body=%s", + _v27_nm, _v27_resp.status_code, + _v27_body, ) - except Exception as _e_v26rb1: + except Exception as _e_v27rb1: self.logger.info( - "[SGLangReadBackMTP-v26] name=%s " - "readback_unavailable err=%s", - _v26_pn, _e_v26rb1, + "[SGLangReadBackMTPv2-v27] name=%s " + "http_err=%s", _v27_nm, _e_v27rb1, ) else: self.logger.info( - "[SGLangReadBackMTP-v26] rollout_engine lacks " - "get_weights_by_parameter_name; cannot read back." + "[SGLangReadBackMTPv2-v27] addresses unavailable; " + "cannot read back (inner_engine=%s).", + type( + getattr(self.rollout_engine, "_engine", None) + ).__name__, ) - except Exception as _e_v26rb: + except Exception as _e_v27rb: self.logger.warning( - "[SGLangReadBackMTP-v26] outer error: %s", - _e_v26rb, + "[SGLangReadBackMTPv2-v27] outer error: %s", + _e_v27rb, ) _t_call1 = _time.time() self.logger.info( From d76a4ff86bbf1ea186a29d9b854ab8a8ba500791 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 15:19:28 +0800 Subject: [PATCH 096/140] feat(infra): read sglang weight for verify --- areal/engine/megatron_engine.py | 97 +++++++++++++++++- areal/infra/controller/rollout_callback.py | 26 +++++ areal/infra/controller/rollout_controller.py | 101 +++++++++++++++++++ areal/infra/remote_inf_engine.py | 11 ++ 4 files changed, 233 insertions(+), 2 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ec6d0b0d06..519bebe161 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -219,9 +219,10 @@ def __init__(self, config: TrainEngineConfig): _os_banner.environ.get( "AREAL_MTP_FP32_MASTER_READ", "1"), } - # [MTPVersionBanner-v27] Added iter14 instrumentation: MTPSerializeSendMTP-v26 / MTPGradProbe-v26 / SGLangReadBackMTP-v26. + # [MTPVersionBanner-v28] Added iter16 instrumentation: SGLangReadBackMTPv3-v28(CallbackPath) / MTPBf16ULPProof-v28. + _banner_tags = list(_banner_tags) + ["v27:SGLangReadBackMTPv2-HTTP+MainGrad+Fp32Delta", "v28:SGLangReadBackMTPv3-CallbackPath+MTPBf16ULPProof"] self.logger.info( - "[MTPVersionBanner-v27] tags=%s flags=%s", + "[MTPVersionBanner-v28] tags=%s flags=%s", ",".join(_banner_tags), _banner_flags, ) try: @@ -3465,6 +3466,49 @@ def _serialize_mtp_tensors_for_update( float(_v26_t.abs().max().item()), _v26_first8, ) + # [MTPBf16ULPProof-v28] Prove/disprove bf16 ULP flooring on receiver side. + try: + import torch as _torch_v28p + if not hasattr(self, "_mtp_v28_prev_bf16cast"): + self._mtp_v28_prev_bf16cast = {} + _v28_tf = _v26_t.float() + _v28_bf16 = _v28_tf.to(_torch_v28p.bfloat16) + _v28_bb = _v28_bf16.float() + _v28_eqcast = int((_v28_tf == _v28_bb).sum().item()) + _v28_numel = int(_v28_tf.numel()) + _v28_frac = _v28_eqcast / max(1, _v28_numel) + _v28_prev = self._mtp_v28_prev_bf16cast.get( + _v26_name + ) + if _v28_prev is None: + _v28_unchanged = None + else: + try: + if _v28_prev.shape == _v28_bb.shape: + _v28_unchanged = int( + (_v28_bb == _v28_prev).sum().item() + ) + else: + _v28_unchanged = -2 + except Exception: + _v28_unchanged = -1 + self._mtp_v28_prev_bf16cast[_v26_name] = ( + _v28_bb.detach().clone() + ) + self.logger.info( + "[MTPBf16ULPProof-v28] name=%s numel=%d " + "fp32_eq_bf16cast=%d (frac=%.4f) " + "bf16cast_eq_prev_bf16cast=%s", + _v26_name, _v28_numel, _v28_eqcast, + _v28_frac, + ("n/a" if _v28_unchanged is None + else str(_v28_unchanged)), + ) + except Exception as _e_v28p: + self.logger.warning( + "[MTPBf16ULPProof-v28] error name=%s: %s", + _v26_name, _e_v28p, + ) except Exception as _e_v26s: self.logger.warning( "[MTPSerializeSendMTP-v26] probe error: %s", _e_v26s, @@ -4712,6 +4756,55 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: "[SGLangReadBackMTPv2-v27] outer error: %s", _e_v27rb, ) + # [SGLangReadBackMTPv3-v28] Callback-chain readback. + # In AReaL single-controller mode, self.rollout_engine + # is a RolloutCallback. v27 used + # RemoteSGLangEngine._engine.addresses, which only + # exists on inference-side workers, not on the + # train-side MegatronEngine (log.9 proved it). v28 + # walks callback -> controller -> worker chain that + # already exists for /callback/update_weights_tensor. + try: + _v28_probe_names = [ + "model.mtp_layers.0.token_layernorm.weight", + "model.mtp_layers.0.hidden_layernorm.weight", + "model.mtp_layers.0.input_layernorm.weight", + "model.mtp_layers.0.post_attention_layernorm.weight", + "model.mtp_layers.0.final_layernorm.weight", + ] + _v28_read = getattr( + self.rollout_engine, + "read_weights_by_name", + None, + ) + if _v28_read is not None: + _v28_resp = _v28_read( + names=_v28_probe_names, truncate_size=8, + ) + _v28_entries = [] + if isinstance(_v28_resp, dict): + _v28_entries = _v28_resp.get("entries", []) + for _ent in _v28_entries: + self.logger.info( + "[SGLangReadBackMTPv3-v28] name=%s " + "status=%s dtype=%s first8=%s body=%s", + _ent.get("name"), + _ent.get("status"), + _ent.get("dtype"), + _ent.get("first8"), + (str(_ent.get("body", ""))[:240]), + ) + else: + self.logger.info( + "[SGLangReadBackMTPv3-v28] rollout_engine " + "lacks read_weights_by_name (engine_type=%s).", + type(self.rollout_engine).__name__, + ) + except Exception as _e_v28rb: + self.logger.warning( + "[SGLangReadBackMTPv3-v28] outer error: %s", + _e_v28rb, + ) _t_call1 = _time.time() self.logger.info( f"[DiagUW] Successfully updated EAGLE draft model " diff --git a/areal/infra/controller/rollout_callback.py b/areal/infra/controller/rollout_callback.py index 202a4c3d2a..68994a75e0 100644 --- a/areal/infra/controller/rollout_callback.py +++ b/areal/infra/controller/rollout_callback.py @@ -93,6 +93,32 @@ def _post(self, endpoint: str, payload: dict[str, Any] | None = None) -> dict: ) raise + def read_weights_by_name( + self, + names, + truncate_size: int = 8, + ) -> dict: + """[v28] Read back SGLang draft weights via the callback chain. + + Forwards to /callback/read_weights_by_name on RolloutController, + which selects the first worker, fetches its RemoteInfEngine + addresses[0], and calls SGLang's /get_weights_by_parameter_name. + """ + payload = { + "names": list(names), + "truncate_size": int(truncate_size), + } + try: + return self._post( + "/callback/read_weights_by_name", payload, + ) + except Exception as e: + logger.warning( + "[DiagMTP][Callback] read_weights_by_name FAILED: %s", + e, + ) + return {"entries": [], "error": str(e)} + def _post_nowait( self, endpoint: str, payload: dict[str, Any] | None = None ) -> Future[dict]: diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 0ad3dab651..da22d7d8cf 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -712,6 +712,28 @@ def update_weights_tensor(): raise return jsonify({"status": "ok"}) + @app.route("/callback/read_weights_by_name", methods=["POST"]) + def read_weights_by_name_route(): + payload = request.get_json() or {} + names = payload.get("names", []) or [] + truncate_size = int(payload.get("truncate_size", 8) or 8) + try: + fut = asyncio.run_coroutine_threadsafe( + self.read_weights_by_name( + names=names, + truncate_size=truncate_size, + ), + self._callback_loop, + ) + entries = fut.result() + return jsonify({"entries": entries}) + except Exception as _e: + logger.warning( + "[DiagMTP] /callback/read_weights_by_name FAILED: %s", + _e, + ) + return jsonify({"entries": [], "error": str(_e)}) + @app.route("/callback/rollout_complete", methods=["POST"]) def rollout_complete(): payload = request.get_json() or {} @@ -1319,6 +1341,85 @@ async def update_weights_from_tensor(self, serialized_payload: dict) -> None: ) raise + async def read_weights_by_name( + self, + names, + truncate_size: int = 8, + ) -> list: + """[v28] Delegate SGLang HTTP read-by-name. + + Uses a lightweight worker RPC to fetch RemoteInfEngine + addresses, then calls SGLang's /get_weights_by_parameter_name + directly over HTTP from the controller process. + """ + import requests as _v28_requests + entries: list = [] + try: + addr_list = await self._collective_rpc_async( + "get_addresses", http_timeout=60.0, + ) + except Exception as _e_addr: + logger.warning( + "[DiagMTP] read_weights_by_name: addr RPC failed: %s", + _e_addr, + ) + addr_list = [] + flat_addrs: list = [] + for a in addr_list or []: + if isinstance(a, (list, tuple)): + flat_addrs.extend(a) + elif a: + flat_addrs.append(a) + if not flat_addrs: + return entries + addr0 = flat_addrs[0] + base = ( + addr0 if str(addr0).startswith("http") + else f"http://{addr0}" + ) + for nm in names: + try: + resp = _v28_requests.post( + f"{base}/get_weights_by_parameter_name", + json={ + "name": nm, + "truncate_size": truncate_size, + }, + timeout=15, + proxies={"http": None, "https": None}, + ) + body = resp.text[:400] + first8 = None + dtype = None + try: + _j = resp.json() + if isinstance(_j, list): + first8 = _j[:8] + elif isinstance(_j, dict): + first8 = ( + _j.get("values") + or _j.get("first8") + ) + dtype = _j.get("dtype") + except Exception: + pass + entries.append({ + "name": nm, + "status": resp.status_code, + "first8": first8, + "dtype": dtype, + "body": body, + }) + except Exception as _e_http: + entries.append({ + "name": nm, + "status": -1, + "first8": None, + "dtype": None, + "body": f"err: {_e_http}", + }) + return entries + def set_version(self, version: int) -> None: with self._version_lock: self._version = version diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index 71e70b0b3a..22150585bb 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -1350,6 +1350,17 @@ def onload(self, tags: list[str] | None = None) -> None: onload_req = self.backend.get_onload_request(tags=tags) self._run_request_on_all_servers(onload_req) + def get_addresses(self) -> list: + """[v28] Expose this worker's inference-server addresses. + + Used by RolloutController.read_weights_by_name to reach + SGLang's /get_weights_by_parameter_name endpoint directly. + """ + try: + return list(self.addresses or []) + except Exception: + return [] + def _run_request_on_all_servers(self, req: HttpRequest): import time as _time From 866e6a1e3c11a335127aea8c19cd0dac3655d45e Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 17:23:39 +0800 Subject: [PATCH 097/140] feat(engine): fix --- areal/engine/megatron_engine.py | 92 +++++++++++++++++++++++++++++++-- areal/engine/sglang_remote.py | 12 +++++ 2 files changed, 101 insertions(+), 3 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 519bebe161..ef7835b7dc 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -219,10 +219,11 @@ def __init__(self, config: TrainEngineConfig): _os_banner.environ.get( "AREAL_MTP_FP32_MASTER_READ", "1"), } - # [MTPVersionBanner-v28] Added iter16 instrumentation: SGLangReadBackMTPv3-v28(CallbackPath) / MTPBf16ULPProof-v28. - _banner_tags = list(_banner_tags) + ["v27:SGLangReadBackMTPv2-HTTP+MainGrad+Fp32Delta", "v28:SGLangReadBackMTPv3-CallbackPath+MTPBf16ULPProof"] + # [MTPVersionBanner-v29] Added iter16 instrumentation: SGLangReadBackMTPv3-v28(CallbackPath) / MTPBf16ULPProof-v28. + _banner_tags = list(_banner_tags) + ["v27:SGLangReadBackMTPv2-HTTP+MainGrad+Fp32Delta", "v28:SGLangReadBackMTPv3-CallbackPath+MTPBf16ULPProof", + "v29:SigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)+SglangGetAddrsFix"] self.logger.info( - "[MTPVersionBanner-v28] tags=%s flags=%s", + "[MTPVersionBanner-v29] tags=%s flags=%s", ",".join(_banner_tags), _banner_flags, ) try: @@ -4227,6 +4228,91 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: "MTP tensors bf16->fp32 (name=%s).", _upcasted, name, ) + # [MTPSigmaDeltaBf16-v29] Residual-carried bf16 + # quantization of the fp32 MTP payload. This is + # the iter17 functional fix that addresses the + # bf16-ULP flooring 100% confirmed from log.10: + # - LayerNorm weights frozen across consecutive + # MTP sync events (MTPBf16ULPProof-v28 showed + # bf16cast_eq_prev_bf16cast==numel on final/ + # input/hidden/post_attention/token_layernorm) + # - fp32 master delta per step ~2e-6 (log.10 + # MTPFp32Delta-v27), well below bf16 ULP at + # |w|=3 (~1.56e-2), so SGLang's bf16 draft + # weight never moves for these tensors. + # We carry each step's round-off residual into + # the next sync event so cumulative sub-ULP + # deltas eventually "tick" the bf16 weight one + # ULP at a time -- classic Sigma-Delta / noise- + # shaped quantization. This turns the SGLang- + # side bf16 target dtype from a hard floor into + # a dithered representation of the fp32 master + # without introducing per-element ULP jitter + # (pure stochastic rounding would do that). + try: + _sd_on = ( + _os_v24m.environ.get( + "AREAL_MTP_SIGMA_DELTA_BF16", "1", + ) == "1" + ) + except Exception: + _sd_on = True + if _sd_on: + if not hasattr(self, "_mtp_sd_residual"): + self._mtp_sd_residual = {} + _sd_applied = 0 + _sd_nonzero_shift = 0 + for _i in range(_prev_count, len(mtp_hf_tensors)): + _nm_sd, _tn_sd = mtp_hf_tensors[_i] + # Only apply to fp32 tensors (the upcast + # step above may have produced fp32; if + # still bf16 here, nothing to do). + if _tn_sd.dtype != _torch_v24m.float32: + continue + _prev_res = self._mtp_sd_residual.get( + _nm_sd + ) + if ( + _prev_res is not None + and _prev_res.shape == _tn_sd.shape + and _prev_res.device == _tn_sd.device + ): + _u = _tn_sd + _prev_res + else: + _u = _tn_sd + # Round-nearest-even to bf16, then back + # to fp32 to compute new residual. + _bf16 = _u.to(_torch_v24m.bfloat16) + _bb = _bf16.float() + _new_res = (_u - _bb).detach().clone() + self._mtp_sd_residual[_nm_sd] = _new_res + # Count how many elements' bf16 state + # differs from the plain RNE(_tn_sd) + # baseline; this is the diagnostic + # "sigma-delta shift". + try: + _baseline_bf16 = _tn_sd.to( + _torch_v24m.bfloat16 + ) + _sd_nonzero_shift += int( + (_bf16 != _baseline_bf16) + .sum() + .item() + ) + except Exception: + pass + mtp_hf_tensors[_i] = ( + _nm_sd, _bf16.contiguous(), + ) + _sd_applied += 1 + if _sd_applied > 0: + self.logger.info( + "[MTPSigmaDeltaBf16-v29] name=%s " + "applied=%d total_shifted_elems=%d", + name, _sd_applied, + _sd_nonzero_shift, + ) + # [MTPSigmaDeltaBf16-v29] END # [MTPWeightDeltaD15] version-to-version # abs_mean delta tracker. if not hasattr(self, "_mtp_d15_prev_abs_mean"): diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 7abe90e7f9..6880843a02 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -463,6 +463,18 @@ def update_weights_from_distributed( """Update weights from distributed memory.""" return self._engine.update_weights_from_distributed(meta, param_specs) + def get_addresses(self) -> list: + """[v29] Delegate address lookup to the composed RemoteInfEngine. + + Needed so RolloutController._collective_rpc_async("get_addresses") + resolves on the RemoteSGLangEngine wrapper registered on the + rollout worker. + """ + try: + return self._engine.get_addresses() + except Exception: + return [] + def update_weights_from_tensor( self, named_tensors: list[tuple[str, torch.Tensor]] | None = None, From dc6571bed35c515251cb64bec2dcf2aab1f65652 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 19:58:56 +0800 Subject: [PATCH 098/140] feat(megatron_engine): verify log --- areal/engine/megatron_engine.py | 223 ++++++++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ef7835b7dc..4185cb4e34 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -207,6 +207,7 @@ def __init__(self, config: TrainEngineConfig): "v23:NoDPCollective+TPConsensus+MainParamViewDirect", "v24:TPDtypeSymmetric+NonPPHeadAlsoFp32", "v25:AcceptEMA+GradProbe+PostOptim+Bf16Drift+SendPreBcast", + "v30:MTPRelativeSpeed+SignalAudit+ReadBackInternal(AREAL_MTP_V30_DIAG)", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": @@ -218,6 +219,9 @@ def __init__(self, config: TrainEngineConfig): "AREAL_MTP_FP32_MASTER_READ": _os_banner.environ.get( "AREAL_MTP_FP32_MASTER_READ", "1"), + "AREAL_MTP_V30_DIAG": + _os_banner.environ.get( + "AREAL_MTP_V30_DIAG", "1"), } # [MTPVersionBanner-v29] Added iter16 instrumentation: SGLangReadBackMTPv3-v28(CallbackPath) / MTPBf16ULPProof-v28. _banner_tags = list(_banner_tags) + ["v27:SGLangReadBackMTPv2-HTTP+MainGrad+Fp32Delta", "v28:SGLangReadBackMTPv3-CallbackPath+MTPBf16ULPProof", @@ -4483,6 +4487,225 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"buffer_size={buffer_size}" ) + # [MTPRelativeSpeed-v30] Collect fp32 master-weight snapshots + # for BOTH MTP-draft parameters AND a canonical backbone sample. + # Computed BEFORE the bf16 serialization/broadcast so the ratio + # reflects what actually moved in the optimizer master copy + # (i.e. Adam-driven signal, not bf16-rounded wire values). + # + # H1 (MTP relative speed too slow) is confirmed iff ratio <= 0.1 + # persistently across weight syncs while accept_ema is declining. + # H1 is rejected if ratio >= 1.0 (MTP is updating at least as + # fast as the backbone in relative terms). + try: + import os as _os_v30 + _v30_on = _os_v30.environ.get("AREAL_MTP_V30_DIAG", "1") == "1" + except Exception: + _v30_on = False + if _v30_on and mtp_hf_tensors: + try: + import torch as _torch_v30 + # ---- Current MTP fp32 master norm (from optimizer) ---- + _mtp_master_norm_sq = 0.0 + _mtp_master_count = 0 + _bb_master_norm_sq = 0.0 + _bb_master_count = 0 + for _n_v30, _p_v30 in get_named_parameters( + self.model, num_moe_experts + ): + if ".experts." in _n_v30: + continue + _data_v30 = getattr(_p_v30, "main_grad", None) + # main_grad is the grad buffer; master weights sit in + # _p_v30.data.float() for Megatron DistOpt. Use .data. + _t_v30 = _p_v30.data + if _t_v30 is None: + continue + _f_v30 = _t_v30.detach().float() + _sq = float((_f_v30 * _f_v30).sum().item()) + _num = int(_f_v30.numel()) + if ".mtp." in _n_v30: + _mtp_master_norm_sq += _sq + _mtp_master_count += _num + else: + _bb_master_norm_sq += _sq + _bb_master_count += _num + _mtp_master_norm = (_mtp_master_norm_sq ** 0.5) + _bb_master_norm = (_bb_master_norm_sq ** 0.5) + # ---- Delta vs previous sync ---- + _prev_mtp = getattr(self, "_v30_prev_mtp_master_norm", None) + _prev_bb = getattr(self, "_v30_prev_bb_master_norm", None) + self._v30_prev_mtp_master_norm = _mtp_master_norm + self._v30_prev_bb_master_norm = _bb_master_norm + _rel_speed = None + _d_mtp_rel = None + _d_bb_rel = None + if _prev_mtp is not None and _prev_bb is not None: + _d_mtp = abs(_mtp_master_norm - _prev_mtp) + _d_bb = abs(_bb_master_norm - _prev_bb) + if _mtp_master_norm > 0: + _d_mtp_rel = _d_mtp / _mtp_master_norm + if _bb_master_norm > 0: + _d_bb_rel = _d_bb / _bb_master_norm + if ( + _d_mtp_rel is not None + and _d_bb_rel is not None + and _d_bb_rel > 0 + ): + _rel_speed = _d_mtp_rel / _d_bb_rel + try: + _rank_v30 = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() else 0 + ) + except Exception: + _rank_v30 = 0 + if _rank_v30 == 0: + self.logger.info( + "[MTPRelativeSpeed-v30] version=%s " + "|W_MTP|=%.4e (n=%d) |W_BB|=%.4e (n=%d) " + "d|W_MTP|/|W_MTP|=%s d|W_BB|/|W_BB|=%s " + "rel_speed=%s", + str(meta.version), + _mtp_master_norm, _mtp_master_count, + _bb_master_norm, _bb_master_count, + ("%.4e" % _d_mtp_rel) if _d_mtp_rel is not None else "NA", + ("%.4e" % _d_bb_rel) if _d_bb_rel is not None else "NA", + ("%.4f" % _rel_speed) if _rel_speed is not None else "NA", + ) + # [MTPLossSignalAudit-v30] consolidate the signals + # needed to arbitrate H1 vs H4 on a single line. + _mtp_loss_ema = getattr( + self, "_last_logged_mtp_loss_ema", None + ) + _mtp_loss_raw = getattr( + self, "_last_logged_mtp_loss_raw", None + ) + _task_reward = getattr( + self, "_last_logged_task_reward", None + ) + _entropy_avg = getattr( + self, "_last_logged_entropy_avg", None + ) + _accept_ema = getattr( + self, "_last_logged_accept_ema", None + ) + self.logger.info( + "[MTPLossSignalAudit-v30] version=%s " + "rel_speed=%s mtp_loss_ema=%s mtp_loss_raw=%s " + "task_reward=%s entropy_avg=%s accept_ema=%s " + "H1=%s H4=%s", + str(meta.version), + ("%.4f" % _rel_speed) if _rel_speed is not None else "NA", + ("%.4f" % _mtp_loss_ema) if isinstance(_mtp_loss_ema, (int, float)) else "NA", + ("%.4f" % _mtp_loss_raw) if isinstance(_mtp_loss_raw, (int, float)) else "NA", + ("%.4f" % _task_reward) if isinstance(_task_reward, (int, float)) else "NA", + ("%.4f" % _entropy_avg) if isinstance(_entropy_avg, (int, float)) else "NA", + ("%.4f" % _accept_ema) if isinstance(_accept_ema, (int, float)) else "NA", + ( + "CONFIRMED" if ( + isinstance(_rel_speed, float) + and _rel_speed <= 0.1 + ) else ( + "REJECTED" if ( + isinstance(_rel_speed, float) + and _rel_speed >= 1.0 + ) else "UNKNOWN" + ) + ), + ( + "SUSPECT" if ( + isinstance(_task_reward, (int, float)) + and _task_reward >= 0.9 + and isinstance(_mtp_loss_ema, (int, float)) + and _mtp_loss_ema <= 0.6 + ) else "NORMAL" + ), + ) + except Exception as _e_v30: + try: + self.logger.warning( + "[MTPRelativeSpeed-v30] computation failed: %r", + _e_v30, + ) + except Exception: + pass + # [SGLangReadBackMTPv4-v30] Best-effort internal-Python read-back + # of a single MTP tensor from SGLang workers, bypassing the HTTP + # /get_weights_by_parameter_name endpoint (which returns 404 on + # SGLang 0.5.9). If readback succeeds, the wire-side |W| is + # compared with the training-side |W| just serialized; mismatch + # >= 1e-3 indicates H2 (SGLang not receiving updates). + if _v30_on and mtp_hf_tensors: + try: + _re_v30 = self.rollout_engine + _inner = getattr(_re_v30, "_engine", None) + _inner2 = getattr(_inner, "_engine", None) if _inner is not None else None + _inner3 = getattr(_inner, "inner_engine", None) if _inner is not None else None + _callable_engine = None + for _cand in (_inner3, _inner2, _inner, _re_v30): + if _cand is None: + continue + if hasattr(_cand, "get_weights_by_parameter_name"): + _callable_engine = _cand + break + try: + _rank_rb = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() else 0 + ) + except Exception: + _rank_rb = 0 + if _rank_rb == 0: + if _callable_engine is None: + self.logger.info( + "[SGLangReadBackMTPv4-v30] unavailable: " + "no get_weights_by_parameter_name on " + "rollout_engine chain (types=%s/%s/%s/%s)", + type(_re_v30).__name__, + type(_inner).__name__ if _inner is not None else "None", + type(_inner2).__name__ if _inner2 is not None else "None", + type(_inner3).__name__ if _inner3 is not None else "None", + ) + else: + _probe_name, _probe_tensor = mtp_hf_tensors[0] + _expected_norm = float( + _probe_tensor.detach().float().norm().item() + ) + try: + _wire = _callable_engine.get_weights_by_parameter_name( + _probe_name + ) + if hasattr(_wire, "float"): + _wire_norm = float(_wire.float().norm().item()) + else: + _wire_norm = float("nan") + _gap = abs(_wire_norm - _expected_norm) + _rel_gap = ( + _gap / max(_expected_norm, 1e-12) + ) + self.logger.info( + "[SGLangReadBackMTPv4-v30] name=%s " + "expected_norm=%.6e wire_norm=%.6e " + "rel_gap=%.4e H2=%s", + _probe_name, _expected_norm, _wire_norm, + _rel_gap, + "CONFIRMED" if _rel_gap >= 1e-3 else "REJECTED", + ) + except Exception as _e_rb: + self.logger.info( + "[SGLangReadBackMTPv4-v30] readback " + "call failed: %r (engine=%s)", + _e_rb, type(_callable_engine).__name__, + ) + except Exception as _e_rb_outer: + try: + self.logger.warning( + "[SGLangReadBackMTPv4-v30] outer failure: %r", + _e_rb_outer, + ) + except Exception: + pass if mtp_hf_tensors: # [v5-F3] Compute norms for ALL tensors (was: only first 5). # [v5-F5] Track prev norm per-tensor to surface drift direction From dcb2e44974c4dd3904790799caf0d6eaa11c0e60 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 3 May 2026 21:02:28 +0800 Subject: [PATCH 099/140] fix(engine): mtp issue --- areal/engine/megatron_engine.py | 317 +++++++++++++++++--------------- 1 file changed, 169 insertions(+), 148 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 4185cb4e34..99bd4f0416 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -207,7 +207,7 @@ def __init__(self, config: TrainEngineConfig): "v23:NoDPCollective+TPConsensus+MainParamViewDirect", "v24:TPDtypeSymmetric+NonPPHeadAlsoFp32", "v25:AcceptEMA+GradProbe+PostOptim+Bf16Drift+SendPreBcast", - "v30:MTPRelativeSpeed+SignalAudit+ReadBackInternal(AREAL_MTP_V30_DIAG)", + "v31:MTPRelativeSpeed+SignalAudit+ReadBackHTTP(AREAL_MTP_V30_DIAG)", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": @@ -4487,66 +4487,68 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"buffer_size={buffer_size}" ) - # [MTPRelativeSpeed-v30] Collect fp32 master-weight snapshots - # for BOTH MTP-draft parameters AND a canonical backbone sample. - # Computed BEFORE the bf16 serialization/broadcast so the ratio - # reflects what actually moved in the optimizer master copy - # (i.e. Adam-driven signal, not bf16-rounded wire values). + # [MTPRelativeSpeed-v31] Measure fp32 |W_MTP| from the already- + # upcast mtp_hf_tensors list (v16 AREAL_MTP_FP32_BROADCAST=1 + # guarantees these are fp32), and fp32 |W_BB| by promoting each + # backbone bf16 tensor to fp32 during reduction only. v30 read + # _p.data (bf16 copy) which had ULP=2.2 per element on |W|~284, + # making d|W|/|W| dominated by quantization noise instead of the + # actual Adam master-weight movement. # - # H1 (MTP relative speed too slow) is confirmed iff ratio <= 0.1 - # persistently across weight syncs while accept_ema is declining. - # H1 is rejected if ratio >= 1.0 (MTP is updating at least as - # fast as the backbone in relative terms). + # H1 judgement: + # rel_speed <= 0.1 persistent -> CONFIRMED (MTP too slow) + # rel_speed >= 1.0 persistent -> REJECTED + # otherwise UNKNOWN try: - import os as _os_v30 - _v30_on = _os_v30.environ.get("AREAL_MTP_V30_DIAG", "1") == "1" + import os as _os_v31 + _v31_on = _os_v31.environ.get("AREAL_MTP_V30_DIAG", "1") == "1" except Exception: - _v30_on = False - if _v30_on and mtp_hf_tensors: + _v31_on = False + if _v31_on and mtp_hf_tensors: try: - import torch as _torch_v30 - # ---- Current MTP fp32 master norm (from optimizer) ---- - _mtp_master_norm_sq = 0.0 - _mtp_master_count = 0 - _bb_master_norm_sq = 0.0 - _bb_master_count = 0 - for _n_v30, _p_v30 in get_named_parameters( + import torch as _torch_v31 + # ---- MTP fp32 norm (already fp32 after v16 upcast) ---- + _mtp_sq = 0.0 + _mtp_cnt = 0 + for _nm, _tn in mtp_hf_tensors: + _f = _tn.detach() + if _f.dtype != _torch_v31.float32: + _f = _f.float() + _mtp_sq += float((_f * _f).sum().item()) + _mtp_cnt += int(_f.numel()) + _mtp_norm = _mtp_sq ** 0.5 + # ---- Backbone fp32 norm (promote bf16 -> fp32 on-fly) ---- + _bb_sq = 0.0 + _bb_cnt = 0 + for _nbb, _pbb in get_named_parameters( self.model, num_moe_experts ): - if ".experts." in _n_v30: + if ".experts." in _nbb: continue - _data_v30 = getattr(_p_v30, "main_grad", None) - # main_grad is the grad buffer; master weights sit in - # _p_v30.data.float() for Megatron DistOpt. Use .data. - _t_v30 = _p_v30.data - if _t_v30 is None: + if ".mtp." in _nbb: continue - _f_v30 = _t_v30.detach().float() - _sq = float((_f_v30 * _f_v30).sum().item()) - _num = int(_f_v30.numel()) - if ".mtp." in _n_v30: - _mtp_master_norm_sq += _sq - _mtp_master_count += _num - else: - _bb_master_norm_sq += _sq - _bb_master_count += _num - _mtp_master_norm = (_mtp_master_norm_sq ** 0.5) - _bb_master_norm = (_bb_master_norm_sq ** 0.5) - # ---- Delta vs previous sync ---- - _prev_mtp = getattr(self, "_v30_prev_mtp_master_norm", None) - _prev_bb = getattr(self, "_v30_prev_bb_master_norm", None) - self._v30_prev_mtp_master_norm = _mtp_master_norm - self._v30_prev_bb_master_norm = _bb_master_norm - _rel_speed = None + _tbb = _pbb.detach() + if _tbb is None: + continue + _tbb = _tbb.float() + _bb_sq += float((_tbb * _tbb).sum().item()) + _bb_cnt += int(_tbb.numel()) + _bb_norm = _bb_sq ** 0.5 + # ---- Delta bookkeeping ---- + _prev_mtp = getattr(self, "_v31_prev_mtp_norm", None) + _prev_bb = getattr(self, "_v31_prev_bb_norm", None) + self._v31_prev_mtp_norm = _mtp_norm + self._v31_prev_bb_norm = _bb_norm _d_mtp_rel = None _d_bb_rel = None + _rel_speed = None if _prev_mtp is not None and _prev_bb is not None: - _d_mtp = abs(_mtp_master_norm - _prev_mtp) - _d_bb = abs(_bb_master_norm - _prev_bb) - if _mtp_master_norm > 0: - _d_mtp_rel = _d_mtp / _mtp_master_norm - if _bb_master_norm > 0: - _d_bb_rel = _d_bb / _bb_master_norm + _d_mtp = abs(_mtp_norm - _prev_mtp) + _d_bb = abs(_bb_norm - _prev_bb) + if _mtp_norm > 0: + _d_mtp_rel = _d_mtp / _mtp_norm + if _bb_norm > 0: + _d_bb_rel = _d_bb / _bb_norm if ( _d_mtp_rel is not None and _d_bb_rel is not None @@ -4554,118 +4556,126 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ): _rel_speed = _d_mtp_rel / _d_bb_rel try: - _rank_v30 = ( + _rk = ( torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 ) except Exception: - _rank_v30 = 0 - if _rank_v30 == 0: + _rk = 0 + if _rk == 0: self.logger.info( - "[MTPRelativeSpeed-v30] version=%s " - "|W_MTP|=%.4e (n=%d) |W_BB|=%.4e (n=%d) " + "[MTPRelativeSpeed-v31] version=%s " + "|W_MTP|_fp32=%.6e (n=%d) " + "|W_BB|_fp32=%.6e (n=%d) " "d|W_MTP|/|W_MTP|=%s d|W_BB|/|W_BB|=%s " "rel_speed=%s", - str(meta.version), - _mtp_master_norm, _mtp_master_count, - _bb_master_norm, _bb_master_count, + str(meta.version), _mtp_norm, _mtp_cnt, + _bb_norm, _bb_cnt, ("%.4e" % _d_mtp_rel) if _d_mtp_rel is not None else "NA", ("%.4e" % _d_bb_rel) if _d_bb_rel is not None else "NA", ("%.4f" % _rel_speed) if _rel_speed is not None else "NA", ) - # [MTPLossSignalAudit-v30] consolidate the signals - # needed to arbitrate H1 vs H4 on a single line. - _mtp_loss_ema = getattr( - self, "_last_logged_mtp_loss_ema", None - ) - _mtp_loss_raw = getattr( - self, "_last_logged_mtp_loss_raw", None - ) - _task_reward = getattr( - self, "_last_logged_task_reward", None - ) - _entropy_avg = getattr( - self, "_last_logged_entropy_avg", None - ) + # ---- [MTPLossSignalAudit-v31] real attribute names ---- + _mtp_loss_ema = getattr(self, "_mtp_loss_ema", None) + _mtp_loss_val = getattr(self, "_mtp_loss_value", None) + _mtp_lr_cache = getattr(self, "_last_logged_mtp_lr", None) + # task_reward/entropy come from stats_tracker; fall + # back gracefully across API shapes. + _task_reward = None + _entropy_avg = None + try: + from areal.utils import stats_tracker as _st_v31 + for _attr in ("get", "peek", "last"): + _fn = getattr(_st_v31, _attr, None) + if not callable(_fn): + continue + try: + _task_reward = _fn( + "ppo_actor/task_reward/avg" + ) + _entropy_avg = _fn( + "ppo_actor/update/entropy/avg" + ) + if _task_reward is not None: + break + except Exception: + continue + except Exception: + pass + if _task_reward is None: + _task_reward = getattr( + self, "_last_task_reward_avg", None + ) + if _entropy_avg is None: + _entropy_avg = getattr( + self, "_last_entropy_avg", None + ) _accept_ema = getattr( - self, "_last_logged_accept_ema", None + self, "_last_accept_ema256", None ) + _h1 = "UNKNOWN" + if isinstance(_rel_speed, float): + if _rel_speed <= 0.1: + _h1 = "CONFIRMED" + elif _rel_speed >= 1.0: + _h1 = "REJECTED" + _h4 = "NORMAL" + if ( + isinstance(_task_reward, (int, float)) + and _task_reward >= 0.9 + and isinstance(_mtp_loss_ema, (int, float)) + and _mtp_loss_ema <= 0.6 + ): + _h4 = "SUSPECT" self.logger.info( - "[MTPLossSignalAudit-v30] version=%s " - "rel_speed=%s mtp_loss_ema=%s mtp_loss_raw=%s " + "[MTPLossSignalAudit-v31] version=%s " + "rel_speed=%s |W_MTP|=%.6e |W_BB|=%.6e " + "mtp_loss_ema=%s mtp_loss_raw=%s mtp_lr=%s " "task_reward=%s entropy_avg=%s accept_ema=%s " "H1=%s H4=%s", str(meta.version), ("%.4f" % _rel_speed) if _rel_speed is not None else "NA", + _mtp_norm, _bb_norm, ("%.4f" % _mtp_loss_ema) if isinstance(_mtp_loss_ema, (int, float)) else "NA", - ("%.4f" % _mtp_loss_raw) if isinstance(_mtp_loss_raw, (int, float)) else "NA", + ("%.4f" % _mtp_loss_val) if isinstance(_mtp_loss_val, (int, float)) else "NA", + ("%.3e" % _mtp_lr_cache) if isinstance(_mtp_lr_cache, (int, float)) else "NA", ("%.4f" % _task_reward) if isinstance(_task_reward, (int, float)) else "NA", ("%.4f" % _entropy_avg) if isinstance(_entropy_avg, (int, float)) else "NA", ("%.4f" % _accept_ema) if isinstance(_accept_ema, (int, float)) else "NA", - ( - "CONFIRMED" if ( - isinstance(_rel_speed, float) - and _rel_speed <= 0.1 - ) else ( - "REJECTED" if ( - isinstance(_rel_speed, float) - and _rel_speed >= 1.0 - ) else "UNKNOWN" - ) - ), - ( - "SUSPECT" if ( - isinstance(_task_reward, (int, float)) - and _task_reward >= 0.9 - and isinstance(_mtp_loss_ema, (int, float)) - and _mtp_loss_ema <= 0.6 - ) else "NORMAL" - ), + _h1, _h4, ) - except Exception as _e_v30: + except Exception as _e_v31: try: self.logger.warning( - "[MTPRelativeSpeed-v30] computation failed: %r", - _e_v30, + "[MTPRelativeSpeed-v31] failed: %r", _e_v31, ) except Exception: pass - # [SGLangReadBackMTPv4-v30] Best-effort internal-Python read-back - # of a single MTP tensor from SGLang workers, bypassing the HTTP - # /get_weights_by_parameter_name endpoint (which returns 404 on - # SGLang 0.5.9). If readback succeeds, the wire-side |W| is - # compared with the training-side |W| just serialized; mismatch - # >= 1e-3 indicates H2 (SGLang not receiving updates). - if _v30_on and mtp_hf_tensors: + # [SGLangReadBackMTPv5-v31] HTTP-direct readback via RolloutCallback. + # spec_v1.log.12 proved rollout_engine._engine is None in PR#1176 + # RolloutCallback mode (pure HTTP proxy). v31 dispatches a best- + # effort POST to {controller_addr}/callback/get_mtp_weight_norm + # and treats any non-2xx (incl. 404 when the controller does not + # implement the endpoint) as "unavailable" rather than a failure. + # If the endpoint IS present, the response JSON should contain + # {"norm": , "name": } allowing H2 arbitration. + if _v31_on and mtp_hf_tensors: try: - _re_v30 = self.rollout_engine - _inner = getattr(_re_v30, "_engine", None) - _inner2 = getattr(_inner, "_engine", None) if _inner is not None else None - _inner3 = getattr(_inner, "inner_engine", None) if _inner is not None else None - _callable_engine = None - for _cand in (_inner3, _inner2, _inner, _re_v30): - if _cand is None: - continue - if hasattr(_cand, "get_weights_by_parameter_name"): - _callable_engine = _cand - break + _re = self.rollout_engine + _addr = getattr(_re, "controller_addr", None) try: - _rank_rb = ( + _rk2 = ( torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 ) except Exception: - _rank_rb = 0 - if _rank_rb == 0: - if _callable_engine is None: + _rk2 = 0 + if _rk2 == 0: + if _addr is None: self.logger.info( - "[SGLangReadBackMTPv4-v30] unavailable: " - "no get_weights_by_parameter_name on " - "rollout_engine chain (types=%s/%s/%s/%s)", - type(_re_v30).__name__, - type(_inner).__name__ if _inner is not None else "None", - type(_inner2).__name__ if _inner2 is not None else "None", - type(_inner3).__name__ if _inner3 is not None else "None", + "[SGLangReadBackMTPv5-v31] unavailable: " + "rollout_engine=%s has no controller_addr", + type(_re).__name__, ) else: _probe_name, _probe_tensor = mtp_hf_tensors[0] @@ -4673,35 +4683,46 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _probe_tensor.detach().float().norm().item() ) try: - _wire = _callable_engine.get_weights_by_parameter_name( - _probe_name + import requests as _rq_v31 + _url = ( + f"http://{_addr}/callback/" + f"get_mtp_weight_norm" ) - if hasattr(_wire, "float"): - _wire_norm = float(_wire.float().norm().item()) - else: - _wire_norm = float("nan") - _gap = abs(_wire_norm - _expected_norm) - _rel_gap = ( - _gap / max(_expected_norm, 1e-12) - ) - self.logger.info( - "[SGLangReadBackMTPv4-v30] name=%s " - "expected_norm=%.6e wire_norm=%.6e " - "rel_gap=%.4e H2=%s", - _probe_name, _expected_norm, _wire_norm, - _rel_gap, - "CONFIRMED" if _rel_gap >= 1e-3 else "REJECTED", + _resp = _rq_v31.post( + _url, + json={"name": _probe_name}, + timeout=30.0, + proxies={"http": None, "https": None}, ) + if _resp.status_code == 200: + _j = _resp.json() + _wire = float(_j.get("norm", float("nan"))) + _gap = abs(_wire - _expected_norm) + _rel_gap = _gap / max(_expected_norm, 1e-12) + self.logger.info( + "[SGLangReadBackMTPv5-v31] name=%s " + "expected_norm=%.6e wire_norm=%.6e " + "rel_gap=%.4e H2=%s", + _probe_name, _expected_norm, _wire, + _rel_gap, + "CONFIRMED" if _rel_gap >= 1e-3 + else "REJECTED", + ) + else: + self.logger.info( + "[SGLangReadBackMTPv5-v31] unavailable: " + "endpoint status=%d url=%s", + _resp.status_code, _url, + ) except Exception as _e_rb: self.logger.info( - "[SGLangReadBackMTPv4-v30] readback " - "call failed: %r (engine=%s)", - _e_rb, type(_callable_engine).__name__, + "[SGLangReadBackMTPv5-v31] http failure: " + "%r addr=%s", _e_rb, _addr, ) except Exception as _e_rb_outer: try: self.logger.warning( - "[SGLangReadBackMTPv4-v30] outer failure: %r", + "[SGLangReadBackMTPv5-v31] outer failure: %r", _e_rb_outer, ) except Exception: From 3ee9a06e2eec7e1a7af2bdf0dd4188e15e962c72 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 00:49:00 +0800 Subject: [PATCH 100/140] feat(controller): fix --- areal/engine/megatron_engine.py | 139 ++++++++++++------- areal/infra/controller/rollout_controller.py | 126 +++++++++++++++++ 2 files changed, 213 insertions(+), 52 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 99bd4f0416..ec25659cb2 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -207,7 +207,7 @@ def __init__(self, config: TrainEngineConfig): "v23:NoDPCollective+TPConsensus+MainParamViewDirect", "v24:TPDtypeSymmetric+NonPPHeadAlsoFp32", "v25:AcceptEMA+GradProbe+PostOptim+Bf16Drift+SendPreBcast", - "v31:MTPRelativeSpeed+SignalAudit+ReadBackHTTP(AREAL_MTP_V30_DIAG)", + "v32:AuditLatch+ReadBackEndpoint(AREAL_MTP_V30_DIAG)", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": @@ -2689,6 +2689,20 @@ def export_stats(self) -> dict[str, float]: group=mpu.get_pipeline_model_parallel_group(), ) data.update(data_list[0]) + # [v32] Snapshot the reduced stats dict so the MTP weight- + # sync path can read task_reward / entropy / accept_rate + # without re-entering stats_tracker (which would reset the + # accumulators on export). + try: + self._last_stats_snapshot_v32 = dict(data) + _tr = data.get("ppo_actor/task_reward/avg") + _ea = data.get("ppo_actor/update/entropy/avg") + if isinstance(_tr, (int, float)): + self._last_task_reward_avg = float(_tr) + if isinstance(_ea, (int, float)): + self._last_entropy_avg = float(_ea) + except Exception: + pass return data def offload(self) -> None: @@ -4579,29 +4593,24 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _mtp_loss_ema = getattr(self, "_mtp_loss_ema", None) _mtp_loss_val = getattr(self, "_mtp_loss_value", None) _mtp_lr_cache = getattr(self, "_last_logged_mtp_lr", None) - # task_reward/entropy come from stats_tracker; fall - # back gracefully across API shapes. + # [v32] Read task_reward / entropy from the + # engine-side snapshot populated by our + # export_stats override (see export_stats below). + # The v31 stats_tracker.get('') path + # returned an empty DistributedStatsTracker (get + # is keyed by TRACKER name, not stat name). + _latest = getattr( + self, "_last_stats_snapshot_v32", None + ) _task_reward = None _entropy_avg = None - try: - from areal.utils import stats_tracker as _st_v31 - for _attr in ("get", "peek", "last"): - _fn = getattr(_st_v31, _attr, None) - if not callable(_fn): - continue - try: - _task_reward = _fn( - "ppo_actor/task_reward/avg" - ) - _entropy_avg = _fn( - "ppo_actor/update/entropy/avg" - ) - if _task_reward is not None: - break - except Exception: - continue - except Exception: - pass + if isinstance(_latest, dict): + _task_reward = _latest.get( + "ppo_actor/task_reward/avg" + ) + _entropy_avg = _latest.get( + "ppo_actor/update/entropy/avg" + ) if _task_reward is None: _task_reward = getattr( self, "_last_task_reward_avg", None @@ -4678,46 +4687,72 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: type(_re).__name__, ) else: - _probe_name, _probe_tensor = mtp_hf_tensors[0] - _expected_norm = float( - _probe_tensor.detach().float().norm().item() + # [v32] Probe up to 3 MTP tensors and aggregate + # the biggest rel_gap seen, so a single layer- + # name mismatch does not falsely flag H2. + _probes = mtp_hf_tensors[:3] + _wire_url = ( + f"http://{_addr}/callback/" + f"get_mtp_weight_norm" ) + _max_rel_gap = 0.0 + _any_ok = False + _status_trace = [] try: - import requests as _rq_v31 - _url = ( - f"http://{_addr}/callback/" - f"get_mtp_weight_norm" - ) - _resp = _rq_v31.post( - _url, - json={"name": _probe_name}, - timeout=30.0, - proxies={"http": None, "https": None}, - ) - if _resp.status_code == 200: - _j = _resp.json() - _wire = float(_j.get("norm", float("nan"))) - _gap = abs(_wire - _expected_norm) - _rel_gap = _gap / max(_expected_norm, 1e-12) + import requests as _rq_v32 + for _pn, _pt in _probes: + _exp = float( + _pt.detach().float().norm().item() + ) + _resp = _rq_v32.post( + _wire_url, + json={"name": _pn}, + timeout=30.0, + proxies={"http": None, "https": None}, + ) + _status_trace.append( + f"{_pn}:{_resp.status_code}" + ) + if _resp.status_code == 200: + _any_ok = True + try: + _jj = _resp.json() + except Exception: + _jj = {} + _wn = float( + _jj.get("norm", float("nan")) + ) + _gap = abs(_wn - _exp) + _rg = _gap / max(_exp, 1e-12) + if _rg > _max_rel_gap: + _max_rel_gap = _rg + self.logger.info( + "[SGLangReadBackMTPv6-v32] " + "name=%s exp=%.6e wire=%.6e " + "rel_gap=%.4e", + _pn, _exp, _wn, _rg, + ) + if _any_ok: self.logger.info( - "[SGLangReadBackMTPv5-v31] name=%s " - "expected_norm=%.6e wire_norm=%.6e " - "rel_gap=%.4e H2=%s", - _probe_name, _expected_norm, _wire, - _rel_gap, - "CONFIRMED" if _rel_gap >= 1e-3 + "[SGLangReadBackMTPv6-v32] " + "aggregate max_rel_gap=%.4e " + "H2=%s trace=%s", + _max_rel_gap, + "CONFIRMED" if _max_rel_gap >= 1e-3 else "REJECTED", + _status_trace, ) else: self.logger.info( - "[SGLangReadBackMTPv5-v31] unavailable: " - "endpoint status=%d url=%s", - _resp.status_code, _url, + "[SGLangReadBackMTPv6-v32] " + "unavailable: trace=%s url=%s", + _status_trace, _wire_url, ) except Exception as _e_rb: self.logger.info( - "[SGLangReadBackMTPv5-v31] http failure: " - "%r addr=%s", _e_rb, _addr, + "[SGLangReadBackMTPv6-v32] http " + "failure: %r addr=%s trace=%s", + _e_rb, _addr, _status_trace, ) except Exception as _e_rb_outer: try: diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index da22d7d8cf..7261331935 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -744,6 +744,132 @@ def rollout_complete(): except Exception as e: return jsonify({"error": str(e)}), 500 + # ------------------------------------------------------------ + # [v32] /callback/get_mtp_weight_norm + # ------------------------------------------------------------ + # Proxy endpoint that lets the training-side MegatronEngine + # confirm whether MTP weights it just pushed via + # /update_weights_from_tensor actually landed on the SGLang + # server. Calls SGLang's built-in /get_weights_by_name on + # the first registered inference server (rank 0) and returns + # a scalar Frobenius norm of the parameter plus its dtype. + # + # Payload: {"name": } + # Response: {"name": , "norm": , "dtype": , + # "numel": , "server": } + # + # Any transport failure or unknown param returns HTTP 200 + # with {"error": , "server": } so that + # the training side can distinguish between "endpoint + # missing" (prev versions returning 404/500) and "real H2 + # signal". + @app.route("/callback/get_mtp_weight_norm", methods=["POST"]) + def get_mtp_weight_norm(): + payload = request.get_json() or {} + _name = payload.get("name") + if not _name: + return jsonify( + {"error": "missing 'name'"} + ), 200 + _srv = None + try: + if not self.server_infos: + return jsonify( + {"error": "no server_infos", + "server": None} + ), 200 + _s0 = self.server_infos[0] + _srv = f"{_s0.host}:{_s0.port}" + try: + import math as _math_v32 + import requests as _rq_v32c + except Exception as _e_imp: + return jsonify( + {"error": f"import fail: {_e_imp!r}", + "server": _srv} + ), 200 + _url = f"http://{_srv}/get_weights_by_name" + # truncate_size=-1 returns the full tensor as a + # (nested) python list so we can compute an exact + # Frobenius norm on the wire side. + try: + _r = _rq_v32c.post( + _url, + json={"name": _name, "truncate_size": -1}, + timeout=60.0, + proxies={"http": None, "https": None}, + ) + except Exception as _e_http: + return jsonify( + { + "error": f"http fail: {_e_http!r}", + "server": _srv, + "url": _url, + } + ), 200 + if _r.status_code != 200: + return jsonify( + { + "error": f"sglang status={_r.status_code}", + "server": _srv, + "url": _url, + "body": _r.text[:400], + } + ), 200 + try: + _j = _r.json() + except Exception as _e_js: + return jsonify( + { + "error": f"json fail: {_e_js!r}", + "server": _srv, + } + ), 200 + # sglang may return {'parameter': ...} OR the raw list. + _param = _j + if isinstance(_j, dict): + _param = _j.get("parameter", _j) + # Flatten arbitrarily-nested lists and compute norm. + _sq = 0.0 + _numel = 0 + def _walk(_x): + nonlocal _sq, _numel + if isinstance(_x, list): + for _y in _x: + _walk(_y) + else: + try: + _v = float(_x) + except Exception: + return + _sq += _v * _v + _numel += 1 + try: + _walk(_param) + except Exception as _e_w: + return jsonify( + { + "error": f"walk fail: {_e_w!r}", + "server": _srv, + } + ), 200 + _norm = _sq ** 0.5 + return jsonify( + { + "name": _name, + "norm": _norm, + "numel": _numel, + "server": _srv, + } + ), 200 + except Exception as _e: + logger.warning( + f"[v32] get_mtp_weight_norm unexpected: {_e!r}" + ) + return jsonify( + {"error": repr(_e), "server": _srv} + ), 200 + @app.errorhandler(Exception) def handle_error(e): logger.error( From 59b319caa7538d71b7b5b9c83ce6234b9b9b3c5a Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 01:50:23 +0800 Subject: [PATCH 101/140] feat(controller): fix1 --- areal/engine/megatron_engine.py | 412 ++++++++++++------- areal/infra/controller/rollout_controller.py | 120 +++--- 2 files changed, 319 insertions(+), 213 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ec25659cb2..7f9faa7526 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -199,35 +199,27 @@ def __init__(self, config: TrainEngineConfig): "v12:OptimDump+Sanity", "v14:LRScaleGuard+WeightDeltaGuard", "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", + "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", + "v33:Bf16PayloadNorm+DeterministicProbe(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", - "v21:MTPFp32MasterRead+DefaultsOn" - "(AREAL_MTP_FP32_MASTER_READ,AREAL_MTP_FP32_BROADCAST=1)", - "v22:CollectiveDeadlockFix+PreScan+EarlyFlush", - "v23:NoDPCollective+TPConsensus+MainParamViewDirect", - "v24:TPDtypeSymmetric+NonPPHeadAlsoFp32", - "v25:AcceptEMA+GradProbe+PostOptim+Bf16Drift+SendPreBcast", - "v32:AuditLatch+ReadBackEndpoint(AREAL_MTP_V30_DIAG)", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": _os_banner.environ.get( "AREAL_MTP_FP32_BROADCAST", "1"), + "AREAL_MTP_SIGMA_DELTA_BF16": + _os_banner.environ.get( + "AREAL_MTP_SIGMA_DELTA_BF16", "1"), "AREAL_MTP_NATIVE_AUTOSCALER": _os_banner.environ.get( "AREAL_MTP_NATIVE_AUTOSCALER", "0"), - "AREAL_MTP_FP32_MASTER_READ": - _os_banner.environ.get( - "AREAL_MTP_FP32_MASTER_READ", "1"), "AREAL_MTP_V30_DIAG": _os_banner.environ.get( "AREAL_MTP_V30_DIAG", "1"), } - # [MTPVersionBanner-v29] Added iter16 instrumentation: SGLangReadBackMTPv3-v28(CallbackPath) / MTPBf16ULPProof-v28. - _banner_tags = list(_banner_tags) + ["v27:SGLangReadBackMTPv2-HTTP+MainGrad+Fp32Delta", "v28:SGLangReadBackMTPv3-CallbackPath+MTPBf16ULPProof", - "v29:SigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)+SglangGetAddrsFix"] self.logger.info( - "[MTPVersionBanner-v29] tags=%s flags=%s", + "[MTPVersionBanner] tags=%s flags=%s", ",".join(_banner_tags), _banner_flags, ) try: @@ -4246,30 +4238,57 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: "MTP tensors bf16->fp32 (name=%s).", _upcasted, name, ) - # [MTPSigmaDeltaBf16-v29] Residual-carried bf16 - # quantization of the fp32 MTP payload. This is - # the iter17 functional fix that addresses the - # bf16-ULP flooring 100% confirmed from log.10: - # - LayerNorm weights frozen across consecutive - # MTP sync events (MTPBf16ULPProof-v28 showed - # bf16cast_eq_prev_bf16cast==numel on final/ - # input/hidden/post_attention/token_layernorm) - # - fp32 master delta per step ~2e-6 (log.10 - # MTPFp32Delta-v27), well below bf16 ULP at - # |w|=3 (~1.56e-2), so SGLang's bf16 draft - # weight never moves for these tensors. - # We carry each step's round-off residual into - # the next sync event so cumulative sub-ULP - # deltas eventually "tick" the bf16 weight one - # ULP at a time -- classic Sigma-Delta / noise- - # shaped quantization. This turns the SGLang- - # side bf16 target dtype from a hard floor into - # a dithered representation of the fp32 master - # without introducing per-element ULP jitter - # (pure stochastic rounding would do that). + # [MTPSigmaDeltaBf16-v28] Residual-carried bf16 + # quantization of the fp32 MTP payload. + # + # PURPOSE + # After v16 upcast the MTP payload is fp32. But + # SGLang 0.5.9's draft model storage is bf16 + # (no fp32-draft knob exists) and its + # default_weight_loader does + # `param.data.copy_(loaded_weight)` which rounds + # fp32->bf16 RNE at the destination. When the + # per-step fp32 delta is smaller than half a + # bf16 ULP (e.g. 2e-6 vs 1.56e-2 for |w|=3 on + # LayerNorm) the draft weight is frozen across + # thousands of steps and accept rate stalls. + # (Confirmed from MTPBf16ULPProof diag in + # iter14-17: bf16cast_eq_prev_bf16cast == numel + # for 5/5 consecutive syncs on all LayerNorm + # MTP params.) + # + # FIX + # Per-tensor residual r[name] (fp32) accumulates + # round-off; each sync we send + # bf16 = RNE(fp32 + r_prev) + # r_new = (fp32 + r_prev) - bf16 + # Cumulative sub-ULP deltas eventually cross the + # bf16 ULP and "tick" the draft weight one ULP + # at a time (classic Sigma-Delta quantization). + # Unlike per-element stochastic rounding this + # is deterministic and preserves monotonic + # sub-ULP trajectories. + # + # NOTES + # * slime/verl do not address this. Research of + # https://github.com/THUDM/slime , + # https://github.com/volcengine/verl , SGLang + # v0.5.9 and Megatron-LM core_r0.16.0 confirms + # they all ship bf16 round-to-nearest. See + # megatron distrib_optimizer.py + # _copy_main_params_to_model_params (plain + # copy_) and sglang weight_utils.py + # default_weight_loader (plain copy_). + # * Only bf16 storage on SGLang side is affected. + # If AREAL_MTP_FP32_BROADCAST=0 or upstream + # already materialised fp32, we are a no-op. + # * Only MTP-draft tensors go through this + # block; all other params are untouched. + # + # Gate: AREAL_MTP_SIGMA_DELTA_BF16 (default "1"). try: _sd_on = ( - _os_v24m.environ.get( + _os_v16.environ.get( "AREAL_MTP_SIGMA_DELTA_BF16", "1", ) == "1" ) @@ -4278,59 +4297,100 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if _sd_on: if not hasattr(self, "_mtp_sd_residual"): self._mtp_sd_residual = {} + if not hasattr(self, "_mtp_sd_sync_idx"): + self._mtp_sd_sync_idx = {} _sd_applied = 0 - _sd_nonzero_shift = 0 + _sd_total_shifted = 0 + _sd_sample_details = [] for _i in range(_prev_count, len(mtp_hf_tensors)): _nm_sd, _tn_sd = mtp_hf_tensors[_i] - # Only apply to fp32 tensors (the upcast - # step above may have produced fp32; if - # still bf16 here, nothing to do). - if _tn_sd.dtype != _torch_v24m.float32: + # Only fp32 MTP payload is candidate. + if _tn_sd.dtype != _torch_v16.float32: continue - _prev_res = self._mtp_sd_residual.get( - _nm_sd - ) + _r_prev = self._mtp_sd_residual.get(_nm_sd) if ( - _prev_res is not None - and _prev_res.shape == _tn_sd.shape - and _prev_res.device == _tn_sd.device + _r_prev is not None + and _r_prev.shape == _tn_sd.shape + and _r_prev.device == _tn_sd.device + and _r_prev.dtype == _tn_sd.dtype ): - _u = _tn_sd + _prev_res + _u = _tn_sd + _r_prev + _had_prev = True else: _u = _tn_sd - # Round-nearest-even to bf16, then back - # to fp32 to compute new residual. - _bf16 = _u.to(_torch_v24m.bfloat16) + _had_prev = False + # RNE fp32 -> bf16 and retrieve actual + # quantized fp32 value for residual calc. + _bf16 = _u.to(_torch_v16.bfloat16) _bb = _bf16.float() _new_res = (_u - _bb).detach().clone() - self._mtp_sd_residual[_nm_sd] = _new_res - # Count how many elements' bf16 state - # differs from the plain RNE(_tn_sd) - # baseline; this is the diagnostic - # "sigma-delta shift". + # Diagnostic: count elements whose bf16 + # representation differs from the plain + # RNE(_tn_sd) baseline (i.e. how many were + # "lifted" by accumulated residual). try: _baseline_bf16 = _tn_sd.to( - _torch_v24m.bfloat16 + _torch_v16.bfloat16 ) - _sd_nonzero_shift += int( + _shift_cnt = int( (_bf16 != _baseline_bf16) - .sum() - .item() + .sum().item() ) except Exception: - pass + _shift_cnt = -1 + self._mtp_sd_residual[_nm_sd] = _new_res + self._mtp_sd_sync_idx[_nm_sd] = ( + self._mtp_sd_sync_idx.get(_nm_sd, 0) + 1 + ) + # Replace payload tensor with sigma-delta + # bf16 version. Receiver (SGLang) will do + # its own copy_ which is now bit-exact. mtp_hf_tensors[_i] = ( _nm_sd, _bf16.contiguous(), ) _sd_applied += 1 + if _shift_cnt > 0: + _sd_total_shifted += _shift_cnt + # Per-tensor trace: first 5 tensors or + # every 10th sync, to avoid spam. + if ( + len(_sd_sample_details) < 5 + or ( + self._mtp_sd_sync_idx[_nm_sd] + % 10 == 0 + ) + ): + try: + _r_abs = float( + _new_res.abs().mean().item() + ) + _r_max = float( + _new_res.abs().max().item() + ) + except Exception: + _r_abs, _r_max = -1.0, -1.0 + _sd_sample_details.append( + "name=%s shape=%s had_prev=%s " + "sync_idx=%d shifted_elems=%d " + "residual_abs_mean=%.3e " + "residual_abs_max=%.3e" % ( + _nm_sd, + tuple(_tn_sd.shape), + str(_had_prev), + self._mtp_sd_sync_idx[_nm_sd], + _shift_cnt, + _r_abs, _r_max, + ) + ) if _sd_applied > 0: self.logger.info( - "[MTPSigmaDeltaBf16-v29] name=%s " - "applied=%d total_shifted_elems=%d", - name, _sd_applied, - _sd_nonzero_shift, + "[MTPSigmaDeltaBf16-v28] collect_name=%s " + "applied=%d total_shifted_elems=%d " + "samples=[%s]", + name, + _sd_applied, _sd_total_shifted, + " | ".join(_sd_sample_details), ) - # [MTPSigmaDeltaBf16-v29] END # [MTPWeightDeltaD15] version-to-version # abs_mean delta tracker. if not hasattr(self, "_mtp_d15_prev_abs_mean"): @@ -4660,105 +4720,163 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) except Exception: pass - # [SGLangReadBackMTPv5-v31] HTTP-direct readback via RolloutCallback. - # spec_v1.log.12 proved rollout_engine._engine is None in PR#1176 - # RolloutCallback mode (pure HTTP proxy). v31 dispatches a best- - # effort POST to {controller_addr}/callback/get_mtp_weight_norm - # and treats any non-2xx (incl. 404 when the controller does not - # implement the endpoint) as "unavailable" rather than a failure. - # If the endpoint IS present, the response JSON should contain - # {"norm": , "name": } allowing H2 arbitration. + # [MTPBf16PayloadNorm-v33] Engine-side wire-truth norm. + # After the sigma-delta path above (v28-v29), entries of + # mtp_hf_tensors that correspond to fp32 master MTP params + # have been *replaced* with their bf16 RNE-cast versions + # (see "_bf16.contiguous()" at the sigma-delta tail). Those + # exact bf16 bytes are the payload that sglang's + # eagle_worker.update_weights_from_tensor .copy_()s into + # BOTH draft_model_runner.model AND target_worker.model + # (eagle_worker.py:999). So |W|_bf16_wire IS the ground + # truth for "did the weights on the wire change". No HTTP + # roundtrip needed -> immune to the MiMo /get_weights_by_name + # architectural block that killed v32's readback. if _v31_on and mtp_hf_tensors: try: - _re = self.rollout_engine - _addr = getattr(_re, "controller_addr", None) + import torch as _torch_v33 + _wire_sq = 0.0 + _wire_cnt = 0 + _wire_bf16_cnt = 0 + _wire_fp32_cnt = 0 + _first_name = None + _first_norm = None + for _nm_w, _tn_w in mtp_hf_tensors: + _tw = _tn_w.detach() + if _tw.dtype == _torch_v33.bfloat16: + _wire_bf16_cnt += 1 + elif _tw.dtype == _torch_v33.float32: + _wire_fp32_cnt += 1 + _tf = _tw.float() + _s = float((_tf * _tf).sum().item()) + _wire_sq += _s + _wire_cnt += int(_tf.numel()) + if _first_name is None: + _first_name = _nm_w + _first_norm = _s ** 0.5 + _wire_norm = _wire_sq ** 0.5 + _prev_wire = getattr(self, "_v33_prev_wire_norm", None) + self._v33_prev_wire_norm = _wire_norm + _d_wire = None + _d_wire_rel = None + if _prev_wire is not None and _wire_norm > 0: + _d_wire = abs(_wire_norm - _prev_wire) + _d_wire_rel = _d_wire / _wire_norm try: - _rk2 = ( + _rk_w = ( torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 ) except Exception: - _rk2 = 0 - if _rk2 == 0: - if _addr is None: + _rk_w = 0 + if _rk_w == 0: + _h2_wire = "UNKNOWN" + if _d_wire is not None: + if _d_wire == 0.0: + _h2_wire = "CONFIRMED-STALL" + elif _d_wire_rel is not None and _d_wire_rel < 1e-8: + _h2_wire = "SUSPECT-MICRO" + else: + _h2_wire = "REJECTED" + self.logger.info( + "[MTPBf16PayloadNorm-v33] version=%s " + "|W|_wire=%.6e (n=%d, bf16=%d fp32=%d) " + "d|W|_wire=%s d|W|_wire_rel=%s " + "first=%s first_norm=%s " + "H2_wire=%s", + str(meta.version), + _wire_norm, _wire_cnt, + _wire_bf16_cnt, _wire_fp32_cnt, + ("%.6e" % _d_wire) if _d_wire is not None else "NA", + ("%.4e" % _d_wire_rel) if _d_wire_rel is not None else "NA", + str(_first_name), + ("%.6e" % _first_norm) if _first_norm is not None else "NA", + _h2_wire, + ) + except Exception as _e_v33_wire: + try: + self.logger.warning( + "[MTPBf16PayloadNorm-v33] failed: %r", + _e_v33_wire, + ) + except Exception: + pass + # [MTPProbeLogprob-v33] Deterministic inference probe via + # /callback/get_mtp_probe. Replaces the architecturally- + # broken v32 /callback/get_mtp_weight_norm path. The probe + # posts a fixed prompt with temperature=0, top_p=1, top_k=1, + # max_new_tokens=1, return_logprob=1 to server_infos[0] and + # returns the first input_token_logprob. + if _v31_on and mtp_hf_tensors: + try: + _re_p = self.rollout_engine + _addr_p = getattr(_re_p, "controller_addr", None) + try: + _rk_p = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() else 0 + ) + except Exception: + _rk_p = 0 + if _rk_p == 0: + if _addr_p is None: self.logger.info( - "[SGLangReadBackMTPv5-v31] unavailable: " + "[MTPProbeLogprob-v33] unavailable: " "rollout_engine=%s has no controller_addr", - type(_re).__name__, + type(_re_p).__name__, ) else: - # [v32] Probe up to 3 MTP tensors and aggregate - # the biggest rel_gap seen, so a single layer- - # name mismatch does not falsely flag H2. - _probes = mtp_hf_tensors[:3] - _wire_url = ( - f"http://{_addr}/callback/" - f"get_mtp_weight_norm" - ) - _max_rel_gap = 0.0 - _any_ok = False - _status_trace = [] try: - import requests as _rq_v32 - for _pn, _pt in _probes: - _exp = float( - _pt.detach().float().norm().item() - ) - _resp = _rq_v32.post( - _wire_url, - json={"name": _pn}, - timeout=30.0, - proxies={"http": None, "https": None}, - ) - _status_trace.append( - f"{_pn}:{_resp.status_code}" - ) - if _resp.status_code == 200: - _any_ok = True - try: - _jj = _resp.json() - except Exception: - _jj = {} - _wn = float( - _jj.get("norm", float("nan")) - ) - _gap = abs(_wn - _exp) - _rg = _gap / max(_exp, 1e-12) - if _rg > _max_rel_gap: - _max_rel_gap = _rg - self.logger.info( - "[SGLangReadBackMTPv6-v32] " - "name=%s exp=%.6e wire=%.6e " - "rel_gap=%.4e", - _pn, _exp, _wn, _rg, - ) - if _any_ok: - self.logger.info( - "[SGLangReadBackMTPv6-v32] " - "aggregate max_rel_gap=%.4e " - "H2=%s trace=%s", - _max_rel_gap, - "CONFIRMED" if _max_rel_gap >= 1e-3 - else "REJECTED", - _status_trace, + import requests as _rq_v33 + _probe_url = ( + f"http://{_addr_p}/callback/" + f"get_mtp_probe" + ) + _resp = _rq_v33.post( + _probe_url, + json={"version": int(meta.version)}, + timeout=30.0, + proxies={"http": None, "https": None}, + ) + _status = _resp.status_code + _jp = {} + try: + _jp = _resp.json() + except Exception: + _jp = {} + _lp = _jp.get("logprob", None) + _server = _jp.get("server", None) + _err = _jp.get("error", None) + _prev_lp = getattr( + self, "_v33_prev_probe_logprob", None + ) + if isinstance(_lp, (int, float)): + self._v33_prev_probe_logprob = float(_lp) + _d_lp = ( + None if _prev_lp is None + else abs(float(_lp) - float(_prev_lp)) ) else: - self.logger.info( - "[SGLangReadBackMTPv6-v32] " - "unavailable: trace=%s url=%s", - _status_trace, _wire_url, - ) - except Exception as _e_rb: + _d_lp = None + self.logger.info( + "[MTPProbeLogprob-v33] version=%s " + "status=%s logprob=%s d_logprob=%s " + "server=%s err=%s", + str(meta.version), _status, + ("%.6e" % _lp) if isinstance(_lp, (int, float)) else "NA", + ("%.6e" % _d_lp) if isinstance(_d_lp, (int, float)) else "NA", + _server, _err, + ) + except Exception as _e_p: self.logger.info( - "[SGLangReadBackMTPv6-v32] http " - "failure: %r addr=%s trace=%s", - _e_rb, _addr, _status_trace, + "[MTPProbeLogprob-v33] http failure: %r", + _e_p, ) - except Exception as _e_rb_outer: + except Exception as _e_p_out: try: self.logger.warning( - "[SGLangReadBackMTPv5-v31] outer failure: %r", - _e_rb_outer, + "[MTPProbeLogprob-v33] outer failure: %r", + _e_p_out, ) except Exception: pass diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 7261331935..e9d17d0768 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -744,33 +744,32 @@ def rollout_complete(): except Exception as e: return jsonify({"error": str(e)}), 500 - # ------------------------------------------------------------ - # [v32] /callback/get_mtp_weight_norm - # ------------------------------------------------------------ - # Proxy endpoint that lets the training-side MegatronEngine - # confirm whether MTP weights it just pushed via - # /update_weights_from_tensor actually landed on the SGLang - # server. Calls SGLang's built-in /get_weights_by_name on - # the first registered inference server (rank 0) and returns - # a scalar Frobenius norm of the parameter plus its dtype. - # - # Payload: {"name": } - # Response: {"name": , "norm": , "dtype": , - # "numel": , "server": } - # - # Any transport failure or unknown param returns HTTP 200 - # with {"error": , "server": } so that - # the training side can distinguish between "endpoint - # missing" (prev versions returning 404/500) and "real H2 - # signal". + # [v32] /callback/get_mtp_weight_norm -- DEPRECATED in v33. + # SGLang MiMo does not expose /get_weights_by_name; the + # endpoint always returns {"error": "sglang status=404"}. + # Replaced by /callback/get_mtp_probe (deterministic + # inference probe). Kept as a 200+error stub so old + # MegatronEngine code does not crash on missing route. @app.route("/callback/get_mtp_weight_norm", methods=["POST"]) def get_mtp_weight_norm(): + return jsonify( + {"error": "deprecated-v33", "server": None} + ), 200 + + # ------------------------------------------------------------ + # [v33] /callback/get_mtp_probe + # ------------------------------------------------------------ + # Deterministic inference probe: sends a fixed prompt to + # server_infos[0] with temperature=0, top_p=1, top_k=1, + # max_new_tokens=1, return_logprob=1 and returns the first + # input_token_logprob. This is a functional end-to-end + # check: if the draft model's logprobs change across weight + # syncs, the weights are actually being used by the + # speculative decoder. + @app.route("/callback/get_mtp_probe", methods=["POST"]) + def get_mtp_probe(): payload = request.get_json() or {} - _name = payload.get("name") - if not _name: - return jsonify( - {"error": "missing 'name'"} - ), 200 + _version = payload.get("version", -1) _srv = None try: if not self.server_infos: @@ -781,21 +780,27 @@ def get_mtp_weight_norm(): _s0 = self.server_infos[0] _srv = f"{_s0.host}:{_s0.port}" try: - import math as _math_v32 - import requests as _rq_v32c + import requests as _rq_v33c except Exception as _e_imp: return jsonify( {"error": f"import fail: {_e_imp!r}", "server": _srv} ), 200 - _url = f"http://{_srv}/get_weights_by_name" - # truncate_size=-1 returns the full tensor as a - # (nested) python list so we can compute an exact - # Frobenius norm on the wire side. + _url = f"http://{_srv}/generate" + _probe_text = "The answer is" try: - _r = _rq_v32c.post( + _r = _rq_v33c.post( _url, - json={"name": _name, "truncate_size": -1}, + json={ + "text": _probe_text, + "sampling_params": { + "temperature": 0, + "top_p": 1, + "top_k": 1, + "max_new_tokens": 1, + }, + "return_logprob": True, + }, timeout=60.0, proxies={"http": None, "https": None}, ) @@ -804,7 +809,6 @@ def get_mtp_weight_norm(): { "error": f"http fail: {_e_http!r}", "server": _srv, - "url": _url, } ), 200 if _r.status_code != 200: @@ -812,7 +816,6 @@ def get_mtp_weight_norm(): { "error": f"sglang status={_r.status_code}", "server": _srv, - "url": _url, "body": _r.text[:400], } ), 200 @@ -825,46 +828,31 @@ def get_mtp_weight_norm(): "server": _srv, } ), 200 - # sglang may return {'parameter': ...} OR the raw list. - _param = _j - if isinstance(_j, dict): - _param = _j.get("parameter", _j) - # Flatten arbitrarily-nested lists and compute norm. - _sq = 0.0 - _numel = 0 - def _walk(_x): - nonlocal _sq, _numel - if isinstance(_x, list): - for _y in _x: - _walk(_y) - else: - try: - _v = float(_x) - except Exception: - return - _sq += _v * _v - _numel += 1 + _logprob = None try: - _walk(_param) - except Exception as _e_w: - return jsonify( - { - "error": f"walk fail: {_e_w!r}", - "server": _srv, - } - ), 200 - _norm = _sq ** 0.5 + _meta_out = _j.get("meta_info", {}) + _lp_list = _meta_out.get("input_token_logprobs", []) + if _lp_list: + _logprob = _lp_list[0] + if isinstance(_logprob, list): + _logprob = _logprob[0] + elif isinstance(_logprob, dict): + _logprob = _logprob.get( + "logprob", + _logprob.get("prob", None), + ) + except Exception: + pass return jsonify( { - "name": _name, - "norm": _norm, - "numel": _numel, + "version": _version, + "logprob": _logprob, "server": _srv, } ), 200 except Exception as _e: logger.warning( - f"[v32] get_mtp_weight_norm unexpected: {_e!r}" + f"[v33] get_mtp_probe unexpected: {_e!r}" ) return jsonify( {"error": repr(_e), "server": _srv} From 832025c568138aa6cbfa85a698fa665832b83c3b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 02:14:47 +0800 Subject: [PATCH 102/140] fix(engine): fix --- areal/engine/megatron_engine.py | 13 ++ areal/infra/controller/rollout_controller.py | 134 ++++++++++--------- 2 files changed, 86 insertions(+), 61 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 7f9faa7526..864fcdf80a 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -4295,6 +4295,19 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: except Exception: _sd_on = True if _sd_on: + # [v34] Defensive torch import: v28 SigmaDelta + # block references _torch_v16 but the original + # import was placed inside the v16 upcast + # guard `if _v16_on:`. When the operator runs + # with AREAL_MTP_FP32_BROADCAST=0 (or unset, + # default "0"), _torch_v16 is undefined and + # the Σ-Δ path raises NameError at + # `_tn_sd.dtype != _torch_v16.float32` during + # update_weights, aborting training. Importing + # torch here is always-safe (module cache) + # and restores Σ-Δ independence from the v16 + # env gate. + import torch as _torch_v16 if not hasattr(self, "_mtp_sd_residual"): self._mtp_sd_residual = {} if not hasattr(self, "_mtp_sd_sync_idx"): diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index e9d17d0768..e76fa16da8 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -744,37 +744,44 @@ def rollout_complete(): except Exception as e: return jsonify({"error": str(e)}), 500 - # [v32] /callback/get_mtp_weight_norm -- DEPRECATED in v33. - # SGLang MiMo does not expose /get_weights_by_name; the - # endpoint always returns {"error": "sglang status=404"}. - # Replaced by /callback/get_mtp_probe (deterministic - # inference probe). Kept as a 200+error stub so old - # MegatronEngine code does not crash on missing route. + # [v32] /callback/get_mtp_weight_norm (DEPRECATED STUB) + # ------------------------------------------------------------ + # v32 attempted to read MTP weights back from sglang via + # /get_weights_by_name, but MiMoForCausalLM does not override + # get_weights_by_name and the scheduler routes the call to + # tp_worker (target), not draft_worker (where MTP layers + # actually live). Architecturally unfixable from our side. + # Kept as a 200-stub so older training images calling the old + # route get a deterministic "deprecated" signal rather than + # a 404-wrapped-as-500. @app.route("/callback/get_mtp_weight_norm", methods=["POST"]) def get_mtp_weight_norm(): return jsonify( - {"error": "deprecated-v33", "server": None} + {"error": "deprecated_v32_route_use_get_mtp_probe"} ), 200 # ------------------------------------------------------------ # [v33] /callback/get_mtp_probe # ------------------------------------------------------------ - # Deterministic inference probe: sends a fixed prompt to + # Deterministic inference probe. Posts /generate to # server_infos[0] with temperature=0, top_p=1, top_k=1, - # max_new_tokens=1, return_logprob=1 and returns the first - # input_token_logprob. This is a functional end-to-end - # check: if the draft model's logprobs change across weight - # syncs, the weights are actually being used by the - # speculative decoder. + # max_new_tokens=1, return_logprob=1 on a fixed prompt, and + # returns the first input_token_logprob as a float. + # + # Payload: {"version": } + # Response: {"version": , "logprob": , + # "server": , "prompt": } @app.route("/callback/get_mtp_probe", methods=["POST"]) def get_mtp_probe(): payload = request.get_json() or {} - _version = payload.get("version", -1) + _version = payload.get("version") _srv = None + _prompt_v33 = "The answer is" try: if not self.server_infos: return jsonify( {"error": "no server_infos", + "version": _version, "server": None} ), 200 _s0 = self.server_infos[0] @@ -784,78 +791,83 @@ def get_mtp_probe(): except Exception as _e_imp: return jsonify( {"error": f"import fail: {_e_imp!r}", + "version": _version, "server": _srv} ), 200 _url = f"http://{_srv}/generate" - _probe_text = "The answer is" + _req = { + "text": _prompt_v33, + "sampling_params": { + "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, + "max_new_tokens": 1, + }, + "return_logprob": True, + "logprob_start_len": 0, + } try: _r = _rq_v33c.post( - _url, - json={ - "text": _probe_text, - "sampling_params": { - "temperature": 0, - "top_p": 1, - "top_k": 1, - "max_new_tokens": 1, - }, - "return_logprob": True, - }, - timeout=60.0, + _url, json=_req, timeout=60.0, proxies={"http": None, "https": None}, ) except Exception as _e_http: return jsonify( - { - "error": f"http fail: {_e_http!r}", - "server": _srv, - } + {"error": f"http fail: {_e_http!r}", + "version": _version, + "server": _srv, "url": _url} ), 200 if _r.status_code != 200: return jsonify( - { - "error": f"sglang status={_r.status_code}", - "server": _srv, - "body": _r.text[:400], - } + {"error": f"sglang status={_r.status_code}", + "version": _version, + "server": _srv, "url": _url, + "body": _r.text[:400]} ), 200 try: _j = _r.json() except Exception as _e_js: return jsonify( - { - "error": f"json fail: {_e_js!r}", - "server": _srv, - } + {"error": f"json fail: {_e_js!r}", + "version": _version, + "server": _srv} + ), 200 + _item = _j if isinstance(_j, dict) else ( + _j[0] if isinstance(_j, list) and _j else {} + ) + _meta = _item.get("meta_info", {}) if isinstance(_item, dict) else {} + _itl = _meta.get("input_token_logprobs", None) + _lp = None + if isinstance(_itl, list) and _itl: + for _e in _itl: + if isinstance(_e, (list, tuple)) and _e: + _cand = _e[0] + if isinstance(_cand, (int, float)): + _lp = float(_cand) + break + elif isinstance(_e, (int, float)): + _lp = float(_e) + break + if _lp is None: + return jsonify( + {"error": "no_input_token_logprob", + "version": _version, + "server": _srv, + "meta_keys": list(_meta.keys()) if isinstance(_meta, dict) else None} ), 200 - _logprob = None - try: - _meta_out = _j.get("meta_info", {}) - _lp_list = _meta_out.get("input_token_logprobs", []) - if _lp_list: - _logprob = _lp_list[0] - if isinstance(_logprob, list): - _logprob = _logprob[0] - elif isinstance(_logprob, dict): - _logprob = _logprob.get( - "logprob", - _logprob.get("prob", None), - ) - except Exception: - pass return jsonify( - { - "version": _version, - "logprob": _logprob, - "server": _srv, - } + {"version": _version, + "logprob": _lp, + "server": _srv, + "prompt": _prompt_v33} ), 200 except Exception as _e: logger.warning( f"[v33] get_mtp_probe unexpected: {_e!r}" ) return jsonify( - {"error": repr(_e), "server": _srv} + {"error": repr(_e), "version": _version, + "server": _srv} ), 200 @app.errorhandler(Exception) From 2dff0d96644fe66208d4ceca9c6a6dedc262c08d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 09:17:23 +0800 Subject: [PATCH 103/140] fix(engine): v36 --- areal/engine/megatron_engine.py | 4 ++-- areal/infra/controller/rollout_controller.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 864fcdf80a..cc4213bfa4 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): "v14:LRScaleGuard+WeightDeltaGuard", "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", - "v33:Bf16PayloadNorm+DeterministicProbe(AREAL_MTP_V30_DIAG)", + "v35:ProbeInputIdsFix+LongerTimeout(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -4848,7 +4848,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _resp = _rq_v33.post( _probe_url, json={"version": int(meta.version)}, - timeout=30.0, + timeout=150.0, proxies={"http": None, "https": None}, ) _status = _resp.status_code diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index e76fa16da8..4cc8335388 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -776,7 +776,8 @@ def get_mtp_probe(): payload = request.get_json() or {} _version = payload.get("version") _srv = None - _prompt_v33 = "The answer is" + _prompt_v33 = "fixed_token_seq_v35" + _probe_ids_v35 = [1, 100, 200, 300, 400, 500, 600, 700] try: if not self.server_infos: return jsonify( @@ -796,7 +797,7 @@ def get_mtp_probe(): ), 200 _url = f"http://{_srv}/generate" _req = { - "text": _prompt_v33, + "input_ids": _probe_ids_v35, "sampling_params": { "temperature": 0.0, "top_p": 1.0, @@ -808,7 +809,7 @@ def get_mtp_probe(): } try: _r = _rq_v33c.post( - _url, json=_req, timeout=60.0, + _url, json=_req, timeout=120.0, proxies={"http": None, "https": None}, ) except Exception as _e_http: From 17cf72e0f6ccddcf123c8ee9f688a9342514b7c9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 10:34:00 +0800 Subject: [PATCH 104/140] fix(engine): fix --- areal/engine/megatron_engine.py | 174 +++++++++++++++++--------------- 1 file changed, 94 insertions(+), 80 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index cc4213bfa4..93deed1aaf 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): "v14:LRScaleGuard+WeightDeltaGuard", "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", - "v35:ProbeInputIdsFix+LongerTimeout(AREAL_MTP_V30_DIAG)", + "v36:ProbeAfterContinue(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -4814,85 +4814,6 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) except Exception: pass - # [MTPProbeLogprob-v33] Deterministic inference probe via - # /callback/get_mtp_probe. Replaces the architecturally- - # broken v32 /callback/get_mtp_weight_norm path. The probe - # posts a fixed prompt with temperature=0, top_p=1, top_k=1, - # max_new_tokens=1, return_logprob=1 to server_infos[0] and - # returns the first input_token_logprob. - if _v31_on and mtp_hf_tensors: - try: - _re_p = self.rollout_engine - _addr_p = getattr(_re_p, "controller_addr", None) - try: - _rk_p = ( - torch.distributed.get_rank() - if torch.distributed.is_initialized() else 0 - ) - except Exception: - _rk_p = 0 - if _rk_p == 0: - if _addr_p is None: - self.logger.info( - "[MTPProbeLogprob-v33] unavailable: " - "rollout_engine=%s has no controller_addr", - type(_re_p).__name__, - ) - else: - try: - import requests as _rq_v33 - _probe_url = ( - f"http://{_addr_p}/callback/" - f"get_mtp_probe" - ) - _resp = _rq_v33.post( - _probe_url, - json={"version": int(meta.version)}, - timeout=150.0, - proxies={"http": None, "https": None}, - ) - _status = _resp.status_code - _jp = {} - try: - _jp = _resp.json() - except Exception: - _jp = {} - _lp = _jp.get("logprob", None) - _server = _jp.get("server", None) - _err = _jp.get("error", None) - _prev_lp = getattr( - self, "_v33_prev_probe_logprob", None - ) - if isinstance(_lp, (int, float)): - self._v33_prev_probe_logprob = float(_lp) - _d_lp = ( - None if _prev_lp is None - else abs(float(_lp) - float(_prev_lp)) - ) - else: - _d_lp = None - self.logger.info( - "[MTPProbeLogprob-v33] version=%s " - "status=%s logprob=%s d_logprob=%s " - "server=%s err=%s", - str(meta.version), _status, - ("%.6e" % _lp) if isinstance(_lp, (int, float)) else "NA", - ("%.6e" % _d_lp) if isinstance(_d_lp, (int, float)) else "NA", - _server, _err, - ) - except Exception as _e_p: - self.logger.info( - "[MTPProbeLogprob-v33] http failure: %r", - _e_p, - ) - except Exception as _e_p_out: - try: - self.logger.warning( - "[MTPProbeLogprob-v33] outer failure: %r", - _e_p_out, - ) - except Exception: - pass if mtp_hf_tensors: # [v5-F3] Compute norms for ALL tensors (was: only first 5). # [v5-F5] Track prev norm per-tensor to surface drift direction @@ -5370,6 +5291,99 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: current_platform.synchronize() dist.barrier(group=self.cpu_group) + # [MTPProbeLogprob-v36] Deterministic inference probe AFTER + # continue_generation. v33/v35 ran the probe inside + # update_weights_from_distributed BEFORE the weight-update + # RPCs, which meant SGLang was PAUSED (not serving) and the + # /generate request always timed out. v36 moves the probe + # to after continue_generation, when SGLang is live and + # serving requests. This is the only point in the training + # loop where the updated weights are guaranteed to be loaded + # AND the inference server is accepting requests. + try: + import os as _os_v36 + _v36_on = _os_v36.environ.get("AREAL_MTP_V30_DIAG", "1") == "1" + except Exception: + _v36_on = False + if _v36_on: + try: + _re_v36 = self.rollout_engine + _addr_v36 = getattr(_re_v36, "controller_addr", None) + try: + _rk_v36 = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() else 0 + ) + except Exception: + _rk_v36 = 0 + if _rk_v36 == 0: + if _addr_v36 is None: + self.logger.info( + "[MTPProbeLogprob-v36] unavailable: " + "rollout_engine=%s has no controller_addr", + type(_re_v36).__name__, + ) + else: + try: + import requests as _rq_v36 + _probe_url = ( + f"http://{_addr_v36}/callback/" + f"get_mtp_probe" + ) + _resp_v36 = _rq_v36.post( + _probe_url, + json={"version": self._weight_version}, + timeout=150.0, + proxies={"http": None, "https": None}, + ) + _status_v36 = _resp_v36.status_code + _jp_v36 = {} + try: + _jp_v36 = _resp_v36.json() + except Exception: + _jp_v36 = {} + _lp_v36 = _jp_v36.get("logprob", None) + _srv_v36 = _jp_v36.get("server", None) + _err_v36 = _jp_v36.get("error", None) + _prev_lp_v36 = getattr( + self, "_v36_prev_probe_logprob", None + ) + if isinstance(_lp_v36, (int, float)): + self._v36_prev_probe_logprob = float( + _lp_v36 + ) + _d_lp_v36 = ( + None if _prev_lp_v36 is None + else abs( + float(_lp_v36) + - float(_prev_lp_v36) + ) + ) + else: + _d_lp_v36 = None + self.logger.info( + "[MTPProbeLogprob-v36] version=%s " + "status=%s logprob=%s d_logprob=%s " + "server=%s err=%s", + self._weight_version, + _status_v36, + ("%.6e" % _lp_v36) if isinstance(_lp_v36, (int, float)) else "NA", + ("%.6e" % _d_lp_v36) if isinstance(_d_lp_v36, (int, float)) else "NA", + _srv_v36, _err_v36, + ) + except Exception as _e_v36: + self.logger.info( + "[MTPProbeLogprob-v36] http failure: %r", + _e_v36, + ) + except Exception as _e_v36_out: + try: + self.logger.warning( + "[MTPProbeLogprob-v36] outer failure: %r", + _e_v36_out, + ) + except Exception: + pass self.logger.info( f"[DiagUW] _update_weights_from_distributed FULLY COMPLETED " f"in {_diag_time.time() - _diag_t0:.3f}s" From 6e32b210e5273cc5ae366d6f701e1899f9c6695b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 14:58:56 +0800 Subject: [PATCH 105/140] fix(engine): fix --- areal/engine/megatron_engine.py | 146 +++++++++++++++++++------------- 1 file changed, 88 insertions(+), 58 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 93deed1aaf..866c33ed8e 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): "v14:LRScaleGuard+WeightDeltaGuard", "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", - "v36:ProbeAfterContinue(AREAL_MTP_V30_DIAG)", + "v37b:ProbeStageTraceback(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -5291,96 +5291,126 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: current_platform.synchronize() dist.barrier(group=self.cpu_group) - # [MTPProbeLogprob-v36] Deterministic inference probe AFTER - # continue_generation. v33/v35 ran the probe inside - # update_weights_from_distributed BEFORE the weight-update - # RPCs, which meant SGLang was PAUSED (not serving) and the - # /generate request always timed out. v36 moves the probe - # to after continue_generation, when SGLang is live and - # serving requests. This is the only point in the training - # loop where the updated weights are guaranteed to be loaded - # AND the inference server is accepting requests. + # [MTPProbeLogprob-v37b] Deterministic inference probe AFTER + # continue_generation, with per-stage try/except + traceback. + # + # v36 failed universally with + # AttributeError: 'MegatronPPOActor' object has no attribute + # '_weight_version' + # because MegatronEngine exposes self._version + get_version(), + # never _weight_version. v37b fixes the attribute and also + # wraps every line of the probe in a per-stage try/except so + # any future failure logs traceback.format_exc() AND a stage + # tag identifying the exact raise site. try: - import os as _os_v36 - _v36_on = _os_v36.environ.get("AREAL_MTP_V30_DIAG", "1") == "1" + import os as _os_v37b + _v37b_on = _os_v37b.environ.get("AREAL_MTP_V30_DIAG", "1") == "1" except Exception: - _v36_on = False - if _v36_on: + _v37b_on = False + if _v37b_on: + _stage_v37b = "enter" try: - _re_v36 = self.rollout_engine - _addr_v36 = getattr(_re_v36, "controller_addr", None) + import traceback as _tb_v37b + _stage_v37b = "get_rollout_engine" + _re_v37b = self.rollout_engine + _stage_v37b = "getattr_controller_addr" + _addr_v37b = getattr(_re_v37b, "controller_addr", None) + _stage_v37b = "get_rank" try: - _rk_v36 = ( + _rk_v37b = ( torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 ) except Exception: - _rk_v36 = 0 - if _rk_v36 == 0: - if _addr_v36 is None: + _rk_v37b = 0 + if _rk_v37b == 0: + if _addr_v37b is None: self.logger.info( - "[MTPProbeLogprob-v36] unavailable: " + "[MTPProbeLogprob-v37b] unavailable: " "rollout_engine=%s has no controller_addr", - type(_re_v36).__name__, + type(_re_v37b).__name__, ) else: try: - import requests as _rq_v36 - _probe_url = ( - f"http://{_addr_v36}/callback/" + _stage_v37b = "import_requests" + import requests as _rq_v37b + _stage_v37b = "build_url" + _probe_url_v37b = ( + f"http://{_addr_v37b}/callback/" f"get_mtp_probe" ) - _resp_v36 = _rq_v36.post( - _probe_url, - json={"version": self._weight_version}, + _stage_v37b = "build_version_int" + _ver_v37b = int(self.get_version()) + _stage_v37b = "http_post" + _resp_v37b = _rq_v37b.post( + _probe_url_v37b, + json={"version": _ver_v37b}, timeout=150.0, proxies={"http": None, "https": None}, ) - _status_v36 = _resp_v36.status_code - _jp_v36 = {} + _stage_v37b = "get_status" + _status_v37b = _resp_v37b.status_code + _stage_v37b = "parse_json" + _jp_v37b = {} try: - _jp_v36 = _resp_v36.json() + _jp_v37b = _resp_v37b.json() except Exception: - _jp_v36 = {} - _lp_v36 = _jp_v36.get("logprob", None) - _srv_v36 = _jp_v36.get("server", None) - _err_v36 = _jp_v36.get("error", None) - _prev_lp_v36 = getattr( - self, "_v36_prev_probe_logprob", None + _jp_v37b = {} + _stage_v37b = "extract_fields" + _lp_v37b = _jp_v37b.get("logprob", None) + _srv_v37b = _jp_v37b.get("server", None) + _err_v37b = _jp_v37b.get("error", None) + _stage_v37b = "get_prev_lp" + _prev_lp_v37b = getattr( + self, "_v37b_prev_probe_logprob", None ) - if isinstance(_lp_v36, (int, float)): - self._v36_prev_probe_logprob = float( - _lp_v36 - ) - _d_lp_v36 = ( - None if _prev_lp_v36 is None + _stage_v37b = "compute_d_lp" + if isinstance(_lp_v37b, (int, float)): + _d_lp_v37b = ( + None if _prev_lp_v37b is None else abs( - float(_lp_v36) - - float(_prev_lp_v36) + float(_lp_v37b) + - float(_prev_lp_v37b) ) ) else: - _d_lp_v36 = None + _d_lp_v37b = None + _stage_v37b = "set_prev_lp_attr" + if isinstance(_lp_v37b, (int, float)): + self._v37b_prev_probe_logprob = float( + _lp_v37b + ) + _stage_v37b = "logger_info_success" self.logger.info( - "[MTPProbeLogprob-v36] version=%s " + "[MTPProbeLogprob-v37b] version=%s " "status=%s logprob=%s d_logprob=%s " "server=%s err=%s", - self._weight_version, - _status_v36, - ("%.6e" % _lp_v36) if isinstance(_lp_v36, (int, float)) else "NA", - ("%.6e" % _d_lp_v36) if isinstance(_d_lp_v36, (int, float)) else "NA", - _srv_v36, _err_v36, + _ver_v37b, + _status_v37b, + ("%.6e" % _lp_v37b) if isinstance(_lp_v37b, (int, float)) else "NA", + ("%.6e" % _d_lp_v37b) if isinstance(_d_lp_v37b, (int, float)) else "NA", + _srv_v37b, _err_v37b, ) - except Exception as _e_v36: + except Exception as _e_v37b: + try: + _tb_str_v37b = _tb_v37b.format_exc() + except Exception: + _tb_str_v37b = "" self.logger.info( - "[MTPProbeLogprob-v36] http failure: %r", - _e_v36, + "[MTPProbeLogprob-v37b] inner failure " + "at stage=%s exc=%r\nTRACEBACK:\n%s", + _stage_v37b, _e_v37b, _tb_str_v37b, ) - except Exception as _e_v36_out: + except Exception as _e_v37b_out: + try: + _tb_out_v37b = _tb_v37b.format_exc() + except Exception: + _tb_out_v37b = "" try: self.logger.warning( - "[MTPProbeLogprob-v36] outer failure: %r", - _e_v36_out, + "[MTPProbeLogprob-v37b] outer failure at " + "stage=%s exc=%r\nTRACEBACK:\n%s", + _stage_v37b, _e_v37b_out, _tb_out_v37b, ) except Exception: pass From c5e7d39002398d3a32e609381cca12b1d120f031 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 15:27:50 +0800 Subject: [PATCH 106/140] feat(controller/engine): add callback log --- areal/engine/megatron_engine.py | 142 ++++++++++++++++++- areal/infra/controller/rollout_controller.py | 135 ++++++++++++++++++ 2 files changed, 276 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 866c33ed8e..0328c5feee 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): "v14:LRScaleGuard+WeightDeltaGuard", "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", - "v37b:ProbeStageTraceback(AREAL_MTP_V30_DIAG)", + "v38:DraftOutputProbe(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -5414,6 +5414,146 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) except Exception: pass + # [DraftOutputProbe-v38] Draft+target OUTPUT SEQUENCE probe. + # v37b only reads input_token_logprobs[0] which is pure target. + # v38 drives /generate with max_new_tokens=32, top_k=1, T=0 + # and records output_ids + output logprobs + any meta_info + # spec/accept fields, so we can see draft+MTP head effects. + # Per-stage try/except + traceback for robustness. + try: + import os as _os_v38 + _v38_on = _os_v38.environ.get("AREAL_MTP_V30_DIAG", "1") == "1" + except Exception: + _v38_on = False + if _v38_on: + _stage_v38 = "enter" + try: + import traceback as _tb_v38 + _stage_v38 = "get_rollout_engine" + _re_v38 = self.rollout_engine + _stage_v38 = "getattr_controller_addr" + _addr_v38 = getattr(_re_v38, "controller_addr", None) + _stage_v38 = "get_rank" + try: + _rk_v38 = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() else 0 + ) + except Exception: + _rk_v38 = 0 + if _rk_v38 == 0 and _addr_v38 is not None: + try: + _stage_v38 = "import_requests" + import requests as _rq_v38 + _stage_v38 = "build_url" + _probe_url_v38 = ( + f"http://{_addr_v38}/callback/" + f"get_draft_probe" + ) + _stage_v38 = "build_version_int" + _ver_v38 = int(self.get_version()) + _stage_v38 = "http_post" + _resp_v38 = _rq_v38.post( + _probe_url_v38, + json={"version": _ver_v38}, + timeout=180.0, + proxies={"http": None, "https": None}, + ) + _stage_v38 = "get_status" + _status_v38 = _resp_v38.status_code + _stage_v38 = "parse_json" + _jp_v38 = {} + try: + _jp_v38 = _resp_v38.json() + except Exception: + _jp_v38 = {} + _stage_v38 = "extract_fields" + _oi_v38 = _jp_v38.get("out_ids_first8", None) + _oi_len_v38 = _jp_v38.get("out_ids_len", None) + _olps_v38 = _jp_v38.get("out_lps_first4", None) + _last_lp_v38 = _jp_v38.get("last_lp", None) + _sum_lp_v38 = _jp_v38.get("sum_lp", None) + _otext_v38 = _jp_v38.get("out_text_head", None) + _mkeys_v38 = _jp_v38.get("meta_keys", None) + _specf_v38 = _jp_v38.get("spec_fields", None) + _err_v38 = _jp_v38.get("error", None) + _stage_v38 = "compute_d_fields" + _prev_oi_v38 = getattr( + self, "_v38_prev_out_ids", None) + _prev_last_lp_v38 = getattr( + self, "_v38_prev_last_lp", None) + _prev_sum_lp_v38 = getattr( + self, "_v38_prev_sum_lp", None) + _d_oi_v38 = None + if (isinstance(_oi_v38, list) + and isinstance(_prev_oi_v38, list) + and len(_oi_v38) == len(_prev_oi_v38)): + _d_oi_v38 = sum( + 1 for _a, _b in zip(_oi_v38, _prev_oi_v38) + if _a != _b + ) + _d_last_lp_v38 = None + if (isinstance(_last_lp_v38, (int, float)) + and isinstance(_prev_last_lp_v38, (int, float))): + _d_last_lp_v38 = abs( + float(_last_lp_v38) + - float(_prev_last_lp_v38) + ) + _d_sum_lp_v38 = None + if (isinstance(_sum_lp_v38, (int, float)) + and isinstance(_prev_sum_lp_v38, (int, float))): + _d_sum_lp_v38 = abs( + float(_sum_lp_v38) + - float(_prev_sum_lp_v38) + ) + _stage_v38 = "set_prev_attrs" + if isinstance(_oi_v38, list): + self._v38_prev_out_ids = list(_oi_v38) + if isinstance(_last_lp_v38, (int, float)): + self._v38_prev_last_lp = float(_last_lp_v38) + if isinstance(_sum_lp_v38, (int, float)): + self._v38_prev_sum_lp = float(_sum_lp_v38) + _stage_v38 = "logger_info_success" + self.logger.info( + "[DraftOutputProbe-v38] version=%s " + "status=%s out_ids_len=%s out_ids=%s " + "d_out_ids_hamming=%s last_lp=%s " + "d_last_lp=%s sum_lp=%s d_sum_lp=%s " + "out_text_head=%r meta_keys=%s " + "spec_fields=%s err=%s", + _ver_v38, _status_v38, + _oi_len_v38, _oi_v38, + _d_oi_v38, + ("%.6e" % _last_lp_v38) if isinstance(_last_lp_v38, (int, float)) else "NA", + ("%.6e" % _d_last_lp_v38) if isinstance(_d_last_lp_v38, (int, float)) else "NA", + ("%.6e" % _sum_lp_v38) if isinstance(_sum_lp_v38, (int, float)) else "NA", + ("%.6e" % _d_sum_lp_v38) if isinstance(_d_sum_lp_v38, (int, float)) else "NA", + _otext_v38, _mkeys_v38, + _specf_v38, _err_v38, + ) + except Exception as _e_v38: + try: + _tb_str_v38 = _tb_v38.format_exc() + except Exception: + _tb_str_v38 = "" + self.logger.info( + "[DraftOutputProbe-v38] inner failure " + "at stage=%s exc=%r\nTRACEBACK:\n%s", + _stage_v38, _e_v38, _tb_str_v38, + ) + except Exception as _e_v38_out: + try: + _tb_out_v38 = _tb_v38.format_exc() + except Exception: + _tb_out_v38 = "" + try: + self.logger.warning( + "[DraftOutputProbe-v38] outer failure at " + "stage=%s exc=%r\nTRACEBACK:\n%s", + _stage_v38, _e_v38_out, _tb_out_v38, + ) + except Exception: + pass self.logger.info( f"[DiagUW] _update_weights_from_distributed FULLY COMPLETED " f"in {_diag_time.time() - _diag_t0:.3f}s" diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 4cc8335388..3fd1f1a006 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -871,6 +871,141 @@ def get_mtp_probe(): "server": _srv} ), 200 + # ------------------------------------------------------------ + # [v38] /callback/get_draft_probe + # ------------------------------------------------------------ + # Output-sequence probe: unlike v33 which reads + # input_token_logprobs[0] (target-model only), this probe + # drives /generate with max_new_tokens=32, temperature=0, + # top_k=1, return_logprob=1 and returns: + # - output_ids (first 8 generated token ids) + # - output_lps (per-position logprob of generated tokens) + # - last_lp (last position logprob) + # - meta_keys (raw meta_info keys, for field discovery) + # - spec_fields (any meta_info key containing 'spec' or + # 'accept' or 'verify' or 'draft') + # When draft+MTP heads change behavior, output_ids or + # output_lps MUST change. If target is frozen but heads + # drift, the joint sequence changes => H3 confirmed. + @app.route("/callback/get_draft_probe", methods=["POST"]) + def get_draft_probe_v38(): + payload = request.get_json() or {} + _version = payload.get("version") + _srv = None + _probe_ids_v38 = [1, 100, 200, 300, 400, 500, 600, 700] + try: + if not self.server_infos: + return jsonify( + {"error": "no server_infos", + "version": _version, + "server": None} + ), 200 + _s0 = self.server_infos[0] + _srv = f"{_s0.host}:{_s0.port}" + try: + import requests as _rq_v38 + except Exception as _e_imp38: + return jsonify( + {"error": f"import fail: {_e_imp38!r}", + "version": _version, + "server": _srv} + ), 200 + _url = f"http://{_srv}/generate" + _req = { + "input_ids": _probe_ids_v38, + "sampling_params": { + "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, + "max_new_tokens": 32, + }, + "return_logprob": True, + "logprob_start_len": 0, + } + try: + _r = _rq_v38.post( + _url, json=_req, timeout=120.0, + proxies={"http": None, "https": None}, + ) + except Exception as _e_http38: + return jsonify( + {"error": f"http fail: {_e_http38!r}", + "version": _version, + "server": _srv, "url": _url} + ), 200 + if _r.status_code != 200: + return jsonify( + {"error": f"sglang status={_r.status_code}", + "version": _version, + "server": _srv, + "body": _r.text[:400]} + ), 200 + try: + _j = _r.json() + except Exception as _e_js38: + return jsonify( + {"error": f"json fail: {_e_js38!r}", + "version": _version, + "server": _srv} + ), 200 + _item = _j if isinstance(_j, dict) else ( + _j[0] if isinstance(_j, list) and _j else {} + ) + _meta = _item.get("meta_info", {}) if isinstance(_item, dict) else {} + _out_text = _item.get("text", None) if isinstance(_item, dict) else None + _otl = _meta.get("output_token_logprobs", None) if isinstance(_meta, dict) else None + _out_ids = [] + _out_lps = [] + if isinstance(_otl, list): + for _e in _otl[:32]: + _lp_i = None + _id_i = None + if isinstance(_e, (list, tuple)) and len(_e) >= 2: + _cand_lp = _e[0] + _cand_id = _e[1] + if isinstance(_cand_lp, (int, float)): + _lp_i = float(_cand_lp) + if isinstance(_cand_id, int): + _id_i = int(_cand_id) + if _id_i is not None: + _out_ids.append(_id_i) + if _lp_i is not None: + _out_lps.append(_lp_i) + _last_lp = _out_lps[-1] if _out_lps else None + _sum_lp = sum(_out_lps) if _out_lps else None + _meta_keys = list(_meta.keys()) if isinstance(_meta, dict) else [] + _spec_fields = {} + if isinstance(_meta, dict): + for _k, _v in _meta.items(): + _kl = str(_k).lower() + if ("spec" in _kl or "accept" in _kl + or "verify" in _kl or "draft" in _kl + or "jump" in _kl): + try: + _spec_fields[str(_k)] = _v + except Exception: + _spec_fields[str(_k)] = repr(_v) + return jsonify( + {"version": _version, + "server": _srv, + "out_ids_first8": _out_ids[:8], + "out_ids_len": len(_out_ids), + "out_lps_first4": _out_lps[:4], + "last_lp": _last_lp, + "sum_lp": _sum_lp, + "out_text_head": (_out_text[:60] if isinstance(_out_text, str) else None), + "meta_keys": _meta_keys, + "spec_fields": _spec_fields} + ), 200 + except Exception as _e38: + logger.warning( + f"[v38] get_draft_probe unexpected: {_e38!r}" + ) + return jsonify( + {"error": repr(_e38), "version": _version, + "server": _srv} + ), 200 + @app.errorhandler(Exception) def handle_error(e): logger.error( From 5e3ead153fe8932dca8b2a3d7c312803e21de28f Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 16:07:19 +0800 Subject: [PATCH 107/140] feat(controller/engine): add more --- areal/engine/megatron_engine.py | 152 ++++++++++++++++- areal/infra/controller/rollout_controller.py | 171 +++++++++++++++++++ 2 files changed, 322 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 0328c5feee..8f204410fc 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): "v14:LRScaleGuard+WeightDeltaGuard", "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", - "v38:DraftOutputProbe(AREAL_MTP_V30_DIAG)", + "v39:LongStochProbe+PerLayerMTPNorm(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -5554,6 +5554,156 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) except Exception: pass + # [DraftSpecTrend-v39] Long + stochastic probes. Plus a + # per-MTP-layer norm scan so heads' individual drift is + # visible instead of aggregated |W_MTP|. + try: + import os as _os_v39 + _v39_on = _os_v39.environ.get("AREAL_MTP_V30_DIAG", "1") == "1" + except Exception: + _v39_on = False + if _v39_on: + _stage_v39 = "enter" + try: + import traceback as _tb_v39 + _stage_v39 = "get_rollout_engine" + _re_v39 = self.rollout_engine + _addr_v39 = getattr(_re_v39, "controller_addr", None) + try: + _rk_v39 = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() else 0 + ) + except Exception: + _rk_v39 = 0 + # --- (a) Per-MTP-layer fp32 norm scan --- + _stage_v39 = "per_layer_norm" + try: + if mtp_hf_tensors: + import torch as _torch_v39 + _layer_norms = {} + for _n, _f in mtp_hf_tensors.items(): + if not hasattr(_f, "dtype"): + continue + try: + if _f.dtype != _torch_v39.float32: + _fc = _f.detach().to(_torch_v39.float32) + else: + _fc = _f.detach() + _nrm = float(_fc.float().norm().item()) + # group by "model.mtp_layers.{i}." + _key = None + _parts = _n.split(".") + if len(_parts) >= 3 and _parts[0] == "model" and _parts[1] == "mtp_layers": + _key = f"mtp_layer_{_parts[2]}" + else: + _key = "other_mtp" + _layer_norms.setdefault(_key, 0.0) + _layer_norms[_key] = (_layer_norms[_key] ** 2 + _nrm ** 2) ** 0.5 + except Exception: + pass + _prev = getattr(self, "_v39_prev_layer_norms", None) + _rel = {} + if isinstance(_prev, dict): + for _k, _v in _layer_norms.items(): + _pv = _prev.get(_k, None) + if isinstance(_pv, (int, float)) and _pv > 0: + _rel[_k] = abs(_v - _pv) / _pv + self._v39_prev_layer_norms = dict(_layer_norms) + if _rk_v39 == 0: + self.logger.info( + "[PerLayerMTPNorm-v39] version=%s " + "norms=%s d_rel=%s", + int(self.get_version()), + {_k: ("%.6e" % _v) for _k, _v in _layer_norms.items()}, + {_k: ("%.3e" % _v) for _k, _v in _rel.items()}, + ) + except Exception as _e_pln: + if _rk_v39 == 0: + try: + self.logger.info( + "[PerLayerMTPNorm-v39] failure: %r\nTRACEBACK:\n%s", + _e_pln, _tb_v39.format_exc(), + ) + except Exception: + pass + # --- (b) Long probe --- + _stage_v39 = "long_probe" + if _rk_v39 == 0 and _addr_v39 is not None: + try: + import requests as _rq_l + _ver = int(self.get_version()) + _r_l = _rq_l.post( + f"http://{_addr_v39}/callback/get_draft_probe_long", + json={"version": _ver}, + timeout=240.0, + proxies={"http": None, "https": None}, + ) + _j_l = _r_l.json() if _r_l.status_code == 200 else {} + self.logger.info( + "[DraftSpecTrend-v39 long] version=%s " + "status=%s out_ids_len=%s " + "first16=%s last16=%s sum_lp=%s mid_lp=%s " + "first_lps=%s last_lps=%s spec=%s err=%s", + _ver, _r_l.status_code, + _j_l.get("out_ids_len"), + _j_l.get("out_ids_first16"), + _j_l.get("out_ids_last16"), + _j_l.get("sum_lp"), + _j_l.get("mid_lp"), + _j_l.get("out_lps_first4"), + _j_l.get("out_lps_last4"), + _j_l.get("spec_fields"), + _j_l.get("error"), + ) + except Exception as _e_l: + self.logger.info( + "[DraftSpecTrend-v39 long] failure: %r\nTRACEBACK:\n%s", + _e_l, _tb_v39.format_exc(), + ) + # --- (c) Stochastic probe --- + _stage_v39 = "stoch_probe" + if _rk_v39 == 0 and _addr_v39 is not None: + try: + import requests as _rq_s + _ver = int(self.get_version()) + _r_s = _rq_s.post( + f"http://{_addr_v39}/callback/get_draft_probe_stoch", + json={"version": _ver}, + timeout=300.0, + proxies={"http": None, "https": None}, + ) + _j_s = _r_s.json() if _r_s.status_code == 200 else {} + self.logger.info( + "[DraftSpecTrend-v39 stoch] version=%s " + "status=%s n_ok=%s " + "spec_accept_rate_stats=%s " + "spec_accept_length_stats=%s " + "spec_accept_rate_samples=%s " + "spec_accept_length_samples=%s " + "histograms=%s err=%s", + _ver, _r_s.status_code, + _j_s.get("n_ok"), + _j_s.get("spec_accept_rate_stats"), + _j_s.get("spec_accept_length_stats"), + _j_s.get("spec_accept_rate_samples"), + _j_s.get("spec_accept_length_samples"), + _j_s.get("histograms"), + _j_s.get("error"), + ) + except Exception as _e_s: + self.logger.info( + "[DraftSpecTrend-v39 stoch] failure: %r\nTRACEBACK:\n%s", + _e_s, _tb_v39.format_exc(), + ) + except Exception as _e_v39_out: + try: + self.logger.warning( + "[DraftSpecTrend-v39] outer failure at stage=%s exc=%r\nTRACEBACK:\n%s", + _stage_v39, _e_v39_out, _tb_v39.format_exc(), + ) + except Exception: + pass self.logger.info( f"[DiagUW] _update_weights_from_distributed FULLY COMPLETED " f"in {_diag_time.time() - _diag_t0:.3f}s" diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 3fd1f1a006..529b7a9e69 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -1006,6 +1006,177 @@ def get_draft_probe_v38(): "server": _srv} ), 200 + # ------------------------------------------------------------ + # [v39] /callback/get_draft_probe_long + # ------------------------------------------------------------ + # Like v38 but with max_new_tokens=128 and collects per-position + # output_token_logprobs + spec_accept_histogram to measure + # draft drift on longer sequences. + @app.route("/callback/get_draft_probe_long", methods=["POST"]) + def get_draft_probe_long_v39(): + payload = request.get_json() or {} + _version = payload.get("version") + _srv = None + _probe_ids = [1, 100, 200, 300, 400, 500, 600, 700] + try: + if not self.server_infos: + return jsonify( + {"error": "no server_infos", "version": _version} + ), 200 + _s0 = self.server_infos[0] + _srv = f"{_s0.host}:{_s0.port}" + import requests as _rq + _url = f"http://{_srv}/generate" + _req = { + "input_ids": _probe_ids, + "sampling_params": { + "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, + "max_new_tokens": 128, + }, + "return_logprob": True, + "logprob_start_len": 0, + } + _r = _rq.post( + _url, json=_req, timeout=180.0, + proxies={"http": None, "https": None}, + ) + if _r.status_code != 200: + return jsonify( + {"error": f"status={_r.status_code}", + "version": _version, + "server": _srv, "body": _r.text[:400]} + ), 200 + _j = _r.json() + _item = _j if isinstance(_j, dict) else ( + _j[0] if isinstance(_j, list) and _j else {} + ) + _meta = _item.get("meta_info", {}) if isinstance(_item, dict) else {} + _otl = _meta.get("output_token_logprobs", None) + _out_ids, _out_lps = [], [] + if isinstance(_otl, list): + for _e in _otl[:128]: + if isinstance(_e, (list, tuple)) and len(_e) >= 2: + _lp_i = _e[0] + _id_i = _e[1] + if isinstance(_id_i, int): + _out_ids.append(int(_id_i)) + if isinstance(_lp_i, (int, float)): + _out_lps.append(float(_lp_i)) + _sum_lp = sum(_out_lps) if _out_lps else None + _mid_lp = _out_lps[len(_out_lps)//2] if _out_lps else None + _specf = {} + for _k, _v in (_meta or {}).items(): + _kl = str(_k).lower() + if ("spec" in _kl or "accept" in _kl + or "verify" in _kl or "draft" in _kl + or "jump" in _kl): + _specf[str(_k)] = _v + return jsonify( + {"version": _version, + "server": _srv, + "out_ids_first16": _out_ids[:16], + "out_ids_last16": _out_ids[-16:] if len(_out_ids) >= 16 else _out_ids, + "out_ids_len": len(_out_ids), + "sum_lp": _sum_lp, + "mid_lp": _mid_lp, + "out_lps_first4": _out_lps[:4], + "out_lps_last4": _out_lps[-4:], + "spec_fields": _specf} + ), 200 + except Exception as _e: + logger.warning(f"[v39] long probe unexpected: {_e!r}") + return jsonify( + {"error": repr(_e), "version": _version} + ), 200 + + # ------------------------------------------------------------ + # [v39] /callback/get_draft_probe_stoch + # ------------------------------------------------------------ + # Stochastic probe: temperature=0.8, top_p=0.95, 8 samples; + # reports mean/std/min/max of spec_accept_length and + # spec_accept_rate across samples to expose draft drift on + # realistic distributions (matching training rollout). + @app.route("/callback/get_draft_probe_stoch", methods=["POST"]) + def get_draft_probe_stoch_v39(): + payload = request.get_json() or {} + _version = payload.get("version") + _srv = None + _probe_ids = [1, 100, 200, 300, 400, 500, 600, 700] + _N = 8 + try: + if not self.server_infos: + return jsonify( + {"error": "no server_infos", "version": _version} + ), 200 + _s0 = self.server_infos[0] + _srv = f"{_s0.host}:{_s0.port}" + import requests as _rq + _url = f"http://{_srv}/generate" + _req = { + "input_ids": _probe_ids, + "sampling_params": { + "temperature": 0.8, + "top_p": 0.95, + "top_k": 0, + "max_new_tokens": 128, + }, + "return_logprob": False, + } + _ar_list, _al_list, _hist_list = [], [], [] + _n_ok = 0 + for _i in range(_N): + try: + _r = _rq.post( + _url, json=_req, timeout=120.0, + proxies={"http": None, "https": None}, + ) + if _r.status_code != 200: + continue + _j = _r.json() + _item = _j if isinstance(_j, dict) else ( + _j[0] if isinstance(_j, list) and _j else {} + ) + _meta = _item.get("meta_info", {}) if isinstance(_item, dict) else {} + _ar = _meta.get("spec_accept_rate", None) + _al = _meta.get("spec_accept_length", None) + _h = _meta.get("spec_accept_histogram", None) + if isinstance(_ar, (int, float)): + _ar_list.append(float(_ar)) + if isinstance(_al, (int, float)): + _al_list.append(float(_al)) + if isinstance(_h, list): + _hist_list.append(list(_h)) + _n_ok += 1 + except Exception: + continue + def _ms(xs): + if not xs: + return {"n": 0, "mean": None, "std": None, + "min": None, "max": None} + _n = len(xs) + _m = sum(xs) / _n + _v = sum((x - _m) ** 2 for x in xs) / _n + _s = _v ** 0.5 + return {"n": _n, "mean": _m, "std": _s, + "min": min(xs), "max": max(xs)} + return jsonify( + {"version": _version, + "server": _srv, + "n_ok": _n_ok, + "spec_accept_rate_stats": _ms(_ar_list), + "spec_accept_length_stats": _ms(_al_list), + "spec_accept_rate_samples": _ar_list, + "spec_accept_length_samples": _al_list, + "histograms": _hist_list} + ), 200 + except Exception as _e: + logger.warning(f"[v39] stoch probe unexpected: {_e!r}") + return jsonify( + {"error": repr(_e), "version": _version} + ), 200 + @app.errorhandler(Exception) def handle_error(e): logger.error( From 2eb714dfef46dc1f9cff5e7b196a44cd0a79aa6c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 16:44:44 +0800 Subject: [PATCH 108/140] feat(rollout): refactor code --- areal/engine/megatron_engine.py | 71 +++++++++++++++++++- areal/infra/controller/rollout_controller.py | 25 +++++-- 2 files changed, 90 insertions(+), 6 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 8f204410fc..d75fed508e 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): "v14:LRScaleGuard+WeightDeltaGuard", "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", - "v39:LongStochProbe+PerLayerMTPNorm(AREAL_MTP_V30_DIAG)", + "v40:AcceptHistTrendAccumulator+FixPerLayerIterAndStochProbe(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -5582,7 +5582,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if mtp_hf_tensors: import torch as _torch_v39 _layer_norms = {} - for _n, _f in mtp_hf_tensors.items(): + for _n, _f in mtp_hf_tensors: if not hasattr(_f, "dtype"): continue try: @@ -5656,6 +5656,73 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _j_l.get("spec_fields"), _j_l.get("error"), ) + # [v40] accept-histogram trend accumulator + try: + _hist_v40 = _j_l.get("spec_fields", {}) or {} + _h_v40 = _hist_v40.get("spec_accept_histogram", None) + _al_v40 = _hist_v40.get("spec_accept_length", None) + _ar_v40 = _hist_v40.get("spec_accept_rate", None) + _trail_v40 = getattr(self, "_v40_long_hist_trail", None) + if _trail_v40 is None: + _trail_v40 = [] + self._v40_long_hist_trail = _trail_v40 + _trail_v40.append( + { + "v": int(self.get_version()), + "h": (list(_h_v40) if isinstance(_h_v40, list) else None), + "al": (float(_al_v40) if isinstance(_al_v40, (int, float)) else None), + "ar": (float(_ar_v40) if isinstance(_ar_v40, (int, float)) else None), + } + ) + # cap trail at 64 + if len(_trail_v40) > 64: + del _trail_v40[0: len(_trail_v40) - 64] + # emit compact trend line + _al_seq = [x["al"] for x in _trail_v40] + _ar_seq = [x["ar"] for x in _trail_v40] + _b2_seq = [ + (x["h"][2] if isinstance(x["h"], list) and len(x["h"]) > 2 else None) + for x in _trail_v40 + ] + _b3_seq = [ + (x["h"][3] if isinstance(x["h"], list) and len(x["h"]) > 3 else None) + for x in _trail_v40 + ] + # monotonic-decline detector (strict <= with at least one strict <) + def _mono_decline(_seq): + _xs = [x for x in _seq if isinstance(x, (int, float))] + if len(_xs) < 3: + return None + _lt = all(_xs[_i] <= _xs[_i - 1] for _i in range(1, len(_xs))) + _any_strict = any(_xs[_i] < _xs[_i - 1] for _i in range(1, len(_xs))) + return bool(_lt and _any_strict) + self.logger.info( + "[AcceptHistTrend-v40] n_versions=%d " + "al_seq=%s ar_seq=%s bucket_accept_len3=%s " + "bucket_accept_len4=%s al_mono_decline=%s " + "ar_mono_decline=%s", + len(_trail_v40), + [ + (None if _v is None else round(_v, 4)) + for _v in _al_seq + ], + [ + (None if _v is None else round(_v, 4)) + for _v in _ar_seq + ], + _b2_seq, + _b3_seq, + _mono_decline(_al_seq), + _mono_decline(_ar_seq), + ) + except Exception as _e_v40_trail: + try: + self.logger.info( + "[AcceptHistTrend-v40] accumulator " + "failure: %r", _e_v40_trail, + ) + except Exception: + pass except Exception as _e_l: self.logger.info( "[DraftSpecTrend-v39 long] failure: %r\nTRACEBACK:\n%s", diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 529b7a9e69..cf6820fd6e 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -1104,7 +1104,7 @@ def get_draft_probe_stoch_v39(): _version = payload.get("version") _srv = None _probe_ids = [1, 100, 200, 300, 400, 500, 600, 700] - _N = 8 + _N = 4 # [v40] halved to keep total timing under 240s try: if not self.server_infos: return jsonify( @@ -1117,10 +1117,11 @@ def get_draft_probe_stoch_v39(): _req = { "input_ids": _probe_ids, "sampling_params": { + # [v40] drop "top_k": 0 — SGLang rejects 0; + # convention is -1 or omit. Keep T=0.8 / top_p=0.95. "temperature": 0.8, "top_p": 0.95, - "top_k": 0, - "max_new_tokens": 128, + "max_new_tokens": 64, }, "return_logprob": False, } @@ -1133,6 +1134,15 @@ def get_draft_probe_stoch_v39(): proxies={"http": None, "https": None}, ) if _r.status_code != 200: + try: + logger.warning( + "[v40] stoch sub-sample %d " + "status=%d body_head=%r", + _i, _r.status_code, + (_r.text[:200] if hasattr(_r, "text") else None), + ) + except Exception: + pass continue _j = _r.json() _item = _j if isinstance(_j, dict) else ( @@ -1149,7 +1159,14 @@ def get_draft_probe_stoch_v39(): if isinstance(_h, list): _hist_list.append(list(_h)) _n_ok += 1 - except Exception: + except Exception as _e_samp_v40: + try: + logger.warning( + "[v40] stoch sub-sample %d failure: %r", + _i, _e_samp_v40, + ) + except Exception: + pass continue def _ms(xs): if not xs: From 20593da7c86eb5e75f0ebfb976cbfc7183b57470 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 17:29:36 +0800 Subject: [PATCH 109/140] feat(engine): patch --- areal/engine/megatron_engine.py | 71 +++++++++++++++++++- areal/engine/sglang_remote.py | 51 ++++++++++++++ areal/infra/controller/rollout_controller.py | 58 ++++++++++++++++ 3 files changed, 179 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index d75fed508e..09a23e5d51 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): "v14:LRScaleGuard+WeightDeltaGuard", "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", - "v40:AcceptHistTrendAccumulator+FixPerLayerIterAndStochProbe(AREAL_MTP_V30_DIAG)", + "v41:RealPromptProbe+ServerInfoProbe+ProductionWindow(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -5723,6 +5723,75 @@ def _mono_decline(_seq): ) except Exception: pass + # [v41] realistic-prompt probe reuse + try: + import requests as _rq_rp41 + _rp_ids = None + try: + from areal.engine.sglang_remote import ( + SGLangRemote as _SGLR41, + ) + _rp_ids = getattr( + _SGLR41, "_v41_last_prompt_ids", None + ) + except Exception: + _rp_ids = None + if isinstance(_rp_ids, list) and len(_rp_ids) >= 4: + _rp_resp = _rq_rp41.post( + f"http://{_addr_v39}/callback/get_draft_probe_long", + json={"version": _ver, "input_ids_override": _rp_ids[:256]}, + timeout=240.0, + proxies={"http": None, "https": None}, + ) + _rp_j = _rp_resp.json() if _rp_resp.status_code == 200 else {} + self.logger.info( + "[RealPromptProbe-v41] version=%s status=%s " + "probe_ids_len=%s probe_ids_head=%s " + "out_ids_len=%s spec=%s", + _ver, _rp_resp.status_code, + _rp_j.get("probe_ids_len"), + _rp_j.get("probe_ids_head"), + _rp_j.get("out_ids_len"), + _rp_j.get("spec_fields"), + ) + else: + self.logger.info( + "[RealPromptProbe-v41] version=%s skipped " + "(no production prompt cached yet)", + _ver, + ) + except Exception as _e_rp41: + try: + self.logger.info( + "[RealPromptProbe-v41] failure: %r", + _e_rp41, + ) + except Exception: + pass + # [v41] server-info probe + try: + import requests as _rq_si41 + _si_resp = _rq_si41.post( + f"http://{_addr_v39}/callback/get_server_info_v41", + json={"version": _ver}, + timeout=60.0, + proxies={"http": None, "https": None}, + ) + _si_j = _si_resp.json() if _si_resp.status_code == 200 else {} + self.logger.info( + "[ServerInfoProbe-v41] version=%s status=%s " + "servers=%s", + _ver, _si_resp.status_code, + _si_j.get("servers"), + ) + except Exception as _e_si41: + try: + self.logger.info( + "[ServerInfoProbe-v41] failure: %r", + _e_si41, + ) + except Exception: + pass except Exception as _e_l: self.logger.info( "[DraftSpecTrend-v39 long] failure: %r\nTRACEBACK:\n%s", diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 6880843a02..fb5fad3c0a 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -157,6 +157,57 @@ def parse_generation_response( rate, self._spec_dec_ema_short, self._spec_dec_ema_long, self._spec_dec_rate_count, self._spec_dec_rate_sum / self._spec_dec_rate_count, ) + # [v41] production window snapshot (256-wide) + try: + _win_size_v41 = 256 + if not hasattr(self, '_v41_win_rates'): + self._v41_win_rates = [] + self._v41_win_completions = [] + self._v41_win_prev_start = 0 + self._v41_win_rates.append(float(rate)) + _ctok_v41 = meta_info.get('completion_tokens', None) + if isinstance(_ctok_v41, (int, float)): + self._v41_win_completions.append(int(_ctok_v41)) + if len(self._v41_win_rates) >= _win_size_v41: + _wr = self._v41_win_rates + _wc = self._v41_win_completions + _m_r = sum(_wr) / len(_wr) + _m_c = (sum(_wc) / len(_wc)) if _wc else None + _vr = sum((x - _m_r) ** 2 for x in _wr) / len(_wr) + _sr = _vr ** 0.5 + _gn = int(self._spec_dec_rate_count) + logger.info( + '[ProductionAcceptWindow-v41] ' + 'win=[%d..%d] n=%d rate_mean=%.4f ' + 'rate_std=%.4f completion_mean=%s ' + 'rate_min=%.4f rate_max=%.4f', + _gn - len(_wr) + 1, _gn, + len(_wr), _m_r, _sr, + ('%.2f' % _m_c) if _m_c is not None else 'NA', + min(_wr), max(_wr), + ) + self._v41_win_rates = [] + self._v41_win_completions = [] + except Exception as _e_v41w: + pass + # [v41] stash last-seen prompt IDs for probe reuse + try: + _pt_v41 = meta_info.get('prompt_tokens', None) + _pti_v41 = None + # SGLang doesn't echo input_ids in meta; instead, + # peek the outer `response` payload for an 'input_ids' + # field if the caller set return_logprob=True. + _resp_in_v41 = response.get('input_ids', None) + if isinstance(_resp_in_v41, list) and _resp_in_v41: + _pti_v41 = [int(x) for x in _resp_in_v41[:256] + if isinstance(x, int)] + if _pti_v41 is not None: + type(self)._v41_last_prompt_ids = _pti_v41 + type(self)._v41_last_prompt_len = ( + int(_pt_v41) if isinstance(_pt_v41, (int, float)) else None + ) + except Exception as _e_v41p: + pass except Exception as _e: pass if stop_reason == "abort" and stop_message.startswith("Abort before prefill"): diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index cf6820fd6e..dde92f6805 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -1018,6 +1018,12 @@ def get_draft_probe_long_v39(): _version = payload.get("version") _srv = None _probe_ids = [1, 100, 200, 300, 400, 500, 600, 700] + # [v41] honor caller override for realistic-prompt probe + _override_v41 = payload.get("input_ids_override") + if isinstance(_override_v41, list) and _override_v41: + _probe_ids = [int(x) for x in _override_v41 if isinstance(x, int)] + if not _probe_ids: + _probe_ids = [1, 100, 200, 300, 400, 500, 600, 700] try: if not self.server_infos: return jsonify( @@ -1076,6 +1082,8 @@ def get_draft_probe_long_v39(): return jsonify( {"version": _version, "server": _srv, + "probe_ids_len": len(_probe_ids), + "probe_ids_head": _probe_ids[:8], "out_ids_first16": _out_ids[:16], "out_ids_last16": _out_ids[-16:] if len(_out_ids) >= 16 else _out_ids, "out_ids_len": len(_out_ids), @@ -1194,6 +1202,56 @@ def _ms(xs): {"error": repr(_e), "version": _version} ), 200 + # ------------------------------------------------------------ + # [v41] /callback/get_server_info_v41 + # Hit SGLang /get_server_info to pull cumulative spec counters + # at the server level, independent of per-request. + # ------------------------------------------------------------ + @app.route("/callback/get_server_info_v41", methods=["POST"]) + def get_server_info_v41(): + payload = request.get_json() or {} + _version = payload.get("version") + try: + if not self.server_infos: + return jsonify( + {"error": "no server_infos", "version": _version} + ), 200 + import requests as _rq41 + _acc = [] + for _s in self.server_infos: + _srv = f"{_s.host}:{_s.port}" + try: + _r = _rq41.get( + f"http://{_srv}/get_server_info", + timeout=30.0, + proxies={"http": None, "https": None}, + ) + if _r.status_code == 200: + _j = _r.json() if _r.headers.get( + "Content-Type", "" + ).startswith("application/json") else {} + # extract only fields of interest to keep + # log line short + _keep = {} + for _k, _v in (_j or {}).items(): + _kl = str(_k).lower() + if ("spec" in _kl or "draft" in _kl + or "accept" in _kl or "token" in _kl + or "version" in _kl or "weight" in _kl): + _keep[str(_k)] = _v + _acc.append({"server": _srv, "info": _keep}) + else: + _acc.append({"server": _srv, "status": _r.status_code, + "body_head": _r.text[:200]}) + except Exception as _e41s: + _acc.append({"server": _srv, "err": repr(_e41s)}) + return jsonify({"version": _version, "servers": _acc}), 200 + except Exception as _e41: + logger.warning(f"[v41] server_info unexpected: {_e41!r}") + return jsonify( + {"error": repr(_e41), "version": _version} + ), 200 + @app.errorhandler(Exception) def handle_error(e): logger.error( From f47606d51ddd189b3f83a984a2b619ae61066b72 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 18:33:45 +0800 Subject: [PATCH 110/140] fix(engine): adapter --- areal/engine/megatron_engine.py | 81 +++++++++++++++++++++++++++++++-- areal/engine/sglang_remote.py | 41 +++++++++-------- 2 files changed, 100 insertions(+), 22 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 09a23e5d51..dbb42515d3 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): "v14:LRScaleGuard+WeightDeltaGuard", "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", - "v41:RealPromptProbe+ServerInfoProbe+ProductionWindow(AREAL_MTP_V30_DIAG)", + "v42:FixRealPromptProbeImport+RequestSidePromptCapture+MTPWeightHash(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -3456,6 +3456,56 @@ def _serialize_mtp_tensors_for_update( f"tensor_dtypes={_tensor_dtypes}, " f"tensor_sizes_bytes={_tensor_sizes}" ) + # [MTPWeightHash-v42] Fingerprint each MTP tensor about to be + # serialised. We hash up to 1024 fp32 values with a + # per-tensor xor-rotate so the 64-bit digest changes on ANY + # modification, without paying for a full-tensor reduction. + # The digest stream is monotonic only if the target-side + # weights are actually being refreshed between versions, + # which lets us discriminate H6 (target/draft sync skew) + # from H5 (policy-phase drift) when accept-rate dips. + try: + import torch as _torch_v42 + _v42_ver = None + try: + _v42_ver = int(self.get_version()) + except Exception: + _v42_ver = None + _v42_digests = [] + for _v42_n, _v42_t in mtp_hf_tensors: + try: + _flat = _v42_t.detach().reshape(-1) + _k = min(1024, int(_flat.numel())) + if _k > 0: + _sl = _flat[:_k].float().contiguous().cpu() + _bytes = _sl.numpy().tobytes() + _h = 0 + for _b in _bytes: + _h = ((_h * 1315423911) ^ int(_b)) & ((1 << 64) - 1) + _s = float(_sl.sum().item()) + _a = float(_sl.abs().mean().item()) + else: + _h, _s, _a = 0, 0.0, 0.0 + _v42_digests.append( + (_v42_n, _h, _s, _a, + tuple(_v42_t.shape), str(_v42_t.dtype)) + ) + except Exception as _e_hash_one: + _v42_digests.append( + (_v42_n, None, None, None, None, + repr(_e_hash_one)) + ) + self.logger.info( + "[MTPWeightHash-v42] version=%s n_tensors=%d digests=%s", + _v42_ver, len(_v42_digests), _v42_digests, + ) + except Exception as _e_hash_all: + try: + self.logger.info( + "[MTPWeightHash-v42] probe failure: %r", _e_hash_all, + ) + except Exception: + pass # [MTPSerializeSendMTP-v26] Sample first 8 values of each MTP # tensor so we can prove the actual bytes placed into the # SGLang IPC payload. The earlier MTPSendPreBcast-v25 probe @@ -5727,14 +5777,37 @@ def _mono_decline(_seq): try: import requests as _rq_rp41 _rp_ids = None + # [v42] fix: the class in sglang_remote.py is + # `SGLangBackend`, NOT `SGLangRemote`. The v41 + # import was silently caught by the outer except + # and left _rp_ids=None for every version, so + # RealPromptProbe was unreachable. Import the + # correct symbol now. try: from areal.engine.sglang_remote import ( - SGLangRemote as _SGLR41, + SGLangBackend as _SGLR42, ) _rp_ids = getattr( - _SGLR41, "_v41_last_prompt_ids", None + _SGLR42, "_v41_last_prompt_ids", None + ) + _rp_len = getattr( + _SGLR42, "_v41_last_prompt_len", None + ) + _rp_ver = getattr( + _SGLR42, "_v42_last_prompt_version", + None, + ) + self.logger.info( + "[RealPromptProbeCapture-v42] " + "present=%s len=%s prompt_ver=%s", + _rp_ids is not None, + _rp_len, _rp_ver, + ) + except Exception as _e_imp42: + self.logger.info( + "[RealPromptProbeCapture-v42] " + "import-fail: %r", _e_imp42, ) - except Exception: _rp_ids = None if isinstance(_rp_ids, list) and len(_rp_ids) >= 4: _rp_resp = _rq_rp41.post( diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index fb5fad3c0a..22194146f0 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -75,6 +75,22 @@ def build_generation_request( "stream": False, } + # [v42] request-side prompt capture (replaces broken + # response-side `response.get("input_ids")` in + # parse_generation_response, which never fires because + # SGLang does not echo input_ids in its /generate reply). + try: + _pids_v42 = req.input_ids + if isinstance(_pids_v42, (list, tuple)) and len(_pids_v42) >= 4: + _pids_head_v42 = [int(x) for x in list(_pids_v42)[:256] + if isinstance(x, (int, float))] + if _pids_head_v42: + type(self)._v41_last_prompt_ids = _pids_head_v42 + type(self)._v41_last_prompt_len = len(_pids_v42) + type(self)._v42_last_prompt_version = int(version) + except Exception: + pass + # Add return_routed_experts to payload if set if req.metadata.get("return_routed_experts", False): payload["return_routed_experts"] = True @@ -190,24 +206,13 @@ def parse_generation_response( self._v41_win_completions = [] except Exception as _e_v41w: pass - # [v41] stash last-seen prompt IDs for probe reuse - try: - _pt_v41 = meta_info.get('prompt_tokens', None) - _pti_v41 = None - # SGLang doesn't echo input_ids in meta; instead, - # peek the outer `response` payload for an 'input_ids' - # field if the caller set return_logprob=True. - _resp_in_v41 = response.get('input_ids', None) - if isinstance(_resp_in_v41, list) and _resp_in_v41: - _pti_v41 = [int(x) for x in _resp_in_v41[:256] - if isinstance(x, int)] - if _pti_v41 is not None: - type(self)._v41_last_prompt_ids = _pti_v41 - type(self)._v41_last_prompt_len = ( - int(_pt_v41) if isinstance(_pt_v41, (int, float)) else None - ) - except Exception as _e_v41p: - pass + # [v42] response-side capture removed. Prompt IDs are + # now captured in build_generation_request() because + # SGLang does not echo input_ids in the /generate + # response and the old response.get('input_ids') lookup + # never fired, leaving RealPromptProbe permanently + # starved. Nothing to do here. + pass except Exception as _e: pass if stop_reason == "abort" and stop_message.startswith("Abort before prefill"): From bb77aca19a089dc0f9d35cdb8b63842173e72f77 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 19:27:09 +0800 Subject: [PATCH 111/140] refactor(engine): fix --- areal/engine/megatron_engine.py | 153 +++++++++++++++++++------------- areal/engine/sglang_remote.py | 33 +++---- 2 files changed, 103 insertions(+), 83 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index dbb42515d3..8c86040bf4 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -200,7 +200,7 @@ def __init__(self, config: TrainEngineConfig): "v14:LRScaleGuard+WeightDeltaGuard", "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", - "v42:FixRealPromptProbeImport+RequestSidePromptCapture+MTPWeightHash(AREAL_MTP_V30_DIAG)", + "v43:FixedLongProbe+MTPWeightHashDelta+CrossProcFix(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -3499,6 +3499,47 @@ def _serialize_mtp_tensors_for_update( "[MTPWeightHash-v42] version=%s n_tensors=%d digests=%s", _v42_ver, len(_v42_digests), _v42_digests, ) + # [v43] delta detector across versions + try: + _cur_map_v43 = {} + for _d in _v42_digests: + if isinstance(_d, tuple) and len(_d) >= 2: + _cur_map_v43[_d[0]] = _d[1] + _prev_map_v43 = getattr(self, "_v43_prev_digests", None) + if isinstance(_prev_map_v43, dict): + _changed_v43 = [] + _same_v43 = [] + for _n, _h in _cur_map_v43.items(): + _ph = _prev_map_v43.get(_n) + if _ph is None: + continue + if _ph == _h: + _same_v43.append(_n) + else: + _changed_v43.append(_n) + self.logger.info( + "[MTPWeightHashDelta-v43] version=%s " + "n_total=%d n_changed=%d n_same=%d " + "changed=%s same=%s", + _v42_ver, len(_cur_map_v43), + len(_changed_v43), len(_same_v43), + _changed_v43, _same_v43, + ) + else: + self.logger.info( + "[MTPWeightHashDelta-v43] version=%s baseline " + "(no prior digest map)", + _v42_ver, + ) + self._v43_prev_digests = _cur_map_v43 + except Exception as _e_delta_v43: + try: + self.logger.info( + "[MTPWeightHashDelta-v43] failure: %r", + _e_delta_v43, + ) + except Exception: + pass except Exception as _e_hash_all: try: self.logger.info( @@ -5773,71 +5814,59 @@ def _mono_decline(_seq): ) except Exception: pass - # [v41] realistic-prompt probe reuse + # [v43] FixedLongProbe: deterministic 128-token + # synthetic prompt. Same IDs every version, so + # AR is a pure function of (target + draft) + # weights. Discriminator: + # production AR dip + FixedLong AR flat -> H5 + # production AR dip + FixedLong AR dip -> H6 + # The probe reuses the existing controller + # endpoint /callback/get_draft_probe_long via + # input_ids_override. try: - import requests as _rq_rp41 - _rp_ids = None - # [v42] fix: the class in sglang_remote.py is - # `SGLangBackend`, NOT `SGLangRemote`. The v41 - # import was silently caught by the outer except - # and left _rp_ids=None for every version, so - # RealPromptProbe was unreachable. Import the - # correct symbol now. + import requests as _rq_fl43 + _fl_ids_v43 = [ + int((i * 37 + 5009) % 50000) for i in range(128) + ] + _fl_resp = _rq_fl43.post( + f"http://{_addr_v39}/callback/get_draft_probe_long", + json={"version": _ver, + "input_ids_override": _fl_ids_v43}, + timeout=240.0, + proxies={"http": None, "https": None}, + ) + _fl_j = _fl_resp.json() if _fl_resp.status_code == 200 else {} + _fl_spec = _fl_j.get("spec_fields") or {} + _fl_rate = None try: - from areal.engine.sglang_remote import ( - SGLangBackend as _SGLR42, - ) - _rp_ids = getattr( - _SGLR42, "_v41_last_prompt_ids", None - ) - _rp_len = getattr( - _SGLR42, "_v41_last_prompt_len", None - ) - _rp_ver = getattr( - _SGLR42, "_v42_last_prompt_version", - None, - ) - self.logger.info( - "[RealPromptProbeCapture-v42] " - "present=%s len=%s prompt_ver=%s", - _rp_ids is not None, - _rp_len, _rp_ver, - ) - except Exception as _e_imp42: - self.logger.info( - "[RealPromptProbeCapture-v42] " - "import-fail: %r", _e_imp42, - ) - _rp_ids = None - if isinstance(_rp_ids, list) and len(_rp_ids) >= 4: - _rp_resp = _rq_rp41.post( - f"http://{_addr_v39}/callback/get_draft_probe_long", - json={"version": _ver, "input_ids_override": _rp_ids[:256]}, - timeout=240.0, - proxies={"http": None, "https": None}, - ) - _rp_j = _rp_resp.json() if _rp_resp.status_code == 200 else {} - self.logger.info( - "[RealPromptProbe-v41] version=%s status=%s " - "probe_ids_len=%s probe_ids_head=%s " - "out_ids_len=%s spec=%s", - _ver, _rp_resp.status_code, - _rp_j.get("probe_ids_len"), - _rp_j.get("probe_ids_head"), - _rp_j.get("out_ids_len"), - _rp_j.get("spec_fields"), - ) - else: - self.logger.info( - "[RealPromptProbe-v41] version=%s skipped " - "(no production prompt cached yet)", - _ver, - ) - except Exception as _e_rp41: + _atn = _fl_spec.get("spec_accept_token_num") + _dtn = _fl_spec.get("spec_draft_token_num") + if (isinstance(_atn, (int, float)) + and isinstance(_dtn, (int, float)) + and _dtn > 0): + _fl_rate = float(_atn) / float(_dtn) + except Exception: + _fl_rate = None + self.logger.info( + "[FixedLongProbe-v43] version=%s status=%s " + "probe_ids_len=%s probe_ids_head=%s " + "out_ids_len=%s sum_lp=%s mid_lp=%s " + "spec_accept_rate=%s spec=%s", + _ver, _fl_resp.status_code, + _fl_j.get("probe_ids_len"), + _fl_j.get("probe_ids_head"), + _fl_j.get("out_ids_len"), + _fl_j.get("sum_lp"), + _fl_j.get("mid_lp"), + ("%.4f" % _fl_rate) if _fl_rate is not None + else "NA", + _fl_spec, + ) + except Exception as _e_fl43: try: self.logger.info( - "[RealPromptProbe-v41] failure: %r", - _e_rp41, + "[FixedLongProbe-v43] failure: %r", + _e_fl43, ) except Exception: pass diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 22194146f0..24839ce72b 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -75,21 +75,14 @@ def build_generation_request( "stream": False, } - # [v42] request-side prompt capture (replaces broken - # response-side `response.get("input_ids")` in - # parse_generation_response, which never fires because - # SGLang does not echo input_ids in its /generate reply). - try: - _pids_v42 = req.input_ids - if isinstance(_pids_v42, (list, tuple)) and len(_pids_v42) >= 4: - _pids_head_v42 = [int(x) for x in list(_pids_v42)[:256] - if isinstance(x, (int, float))] - if _pids_head_v42: - type(self)._v41_last_prompt_ids = _pids_head_v42 - type(self)._v41_last_prompt_len = len(_pids_v42) - type(self)._v42_last_prompt_version = int(version) - except Exception: - pass + # [v43] request-side prompt capture removed. + # It lived in the inference-worker process while the + # corresponding trainer-side read lives in the + # MegatronEngine process, so the class attribute never + # crossed the process boundary (log.26 confirmed + # present=False on every version). v43 replaces the + # whole probe with a process-local synthetic FixedLong + # prompt in megatron_engine.py; nothing to stash here. # Add return_routed_experts to payload if set if req.metadata.get("return_routed_experts", False): @@ -206,12 +199,10 @@ def parse_generation_response( self._v41_win_completions = [] except Exception as _e_v41w: pass - # [v42] response-side capture removed. Prompt IDs are - # now captured in build_generation_request() because - # SGLang does not echo input_ids in the /generate - # response and the old response.get('input_ids') lookup - # never fired, leaving RealPromptProbe permanently - # starved. Nothing to do here. + # [v43] prompt capture removed from both request + # and response paths (cross-process state bug). + # Downstream probe now uses a deterministic + # synthetic prompt instead of a cached real one. pass except Exception as _e: pass From 45927a50c4185b61ec8a30c20361ca66eb744a41 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 20:21:25 +0800 Subject: [PATCH 112/140] feat(engine): fix engine --- areal/engine/megatron_engine.py | 98 +++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 8c86040bf4..3b1843482a 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -201,6 +201,7 @@ def __init__(self, config: TrainEngineConfig): "v16:MTPSerializeFp32Upcast(AREAL_MTP_FP32_BROADCAST)", "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", "v43:FixedLongProbe+MTPWeightHashDelta+CrossProcFix(AREAL_MTP_V30_DIAG)", + "v44:MTPSrcHash+RepeatFixedLongProbe(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -4292,6 +4293,50 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: if _collect_mtp_for_draft and _mtp_param is not None: _mtp_model_name = self.hf_config.model_type _prev_count = len(mtp_hf_tensors) + # [MTPSrcHash-v44] hash Megatron-side collected + # tensor BEFORE convert_to_hf so we can tell if + # hidden_layernorm.weight (digest identical across + # all v43 versions) is frozen at Megatron source + # (training/grad issue) or during HF export path. + try: + import torch as _torch_v44s + _v44s_ver = None + try: + _v44s_ver = int(self.get_version()) + except Exception: + _v44s_ver = None + _v44s_flat = _mtp_param.detach().reshape(-1) + _v44s_k = min(1024, int(_v44s_flat.numel())) + if _v44s_k > 0: + _v44s_sl = ( + _v44s_flat[:_v44s_k].float().contiguous().cpu() + ) + _v44s_bytes = _v44s_sl.numpy().tobytes() + _v44s_h = 0 + for _b in _v44s_bytes: + _v44s_h = ( + (_v44s_h * 1315423911) ^ int(_b) + ) & ((1 << 64) - 1) + _v44s_sum = float(_v44s_sl.sum().item()) + _v44s_am = float(_v44s_sl.abs().mean().item()) + else: + _v44s_h, _v44s_sum, _v44s_am = 0, 0.0, 0.0 + self.logger.info( + "[MTPSrcHash-v44] version=%s name=%s " + "src_dtype=%s src_shape=%s hash=%s " + "sum_first1024=%s abs_mean_first1024=%s", + _v44s_ver, name, + str(_mtp_param.dtype), + tuple(_mtp_param.shape), + _v44s_h, _v44s_sum, _v44s_am, + ) + except Exception as _e_v44s: + try: + self.logger.info( + "[MTPSrcHash-v44] failure: %r", _e_v44s, + ) + except Exception: + pass mtp_hf_tensors.extend( convert_to_hf( self.tf_config, @@ -5870,6 +5915,59 @@ def _mono_decline(_seq): ) except Exception: pass + # [RepeatFixedLongProbe-v44] fire the SAME + # deterministic 128-token prompt again to + # measure within-version stochastic variance. + # If run-1 vs run-2 differ wildly, the AR + # dip is temperature/KV-cache noise, not a + # weight-state shift. + try: + import requests as _rq_rfl44 + _rfl_ids_v44 = [ + int((i * 37 + 5009) % 50000) for i in range(128) + ] + _rfl_resp = _rq_rfl44.post( + f"http://{_addr_v39}/callback/get_draft_probe_long", + json={"version": _ver, + "input_ids_override": _rfl_ids_v44}, + timeout=240.0, + proxies={"http": None, "https": None}, + ) + _rfl_j = ( + _rfl_resp.json() + if _rfl_resp.status_code == 200 else {} + ) + _rfl_spec = _rfl_j.get("spec_fields") or {} + _rfl_rate = None + try: + _atn2 = _rfl_spec.get("spec_accept_token_num") + _dtn2 = _rfl_spec.get("spec_draft_token_num") + if (isinstance(_atn2, (int, float)) + and isinstance(_dtn2, (int, float)) + and _dtn2 > 0): + _rfl_rate = float(_atn2) / float(_dtn2) + except Exception: + _rfl_rate = None + self.logger.info( + "[RepeatFixedLongProbe-v44] version=%s " + "status=%s out_ids_len=%s sum_lp=%s " + "mid_lp=%s spec_accept_rate=%s spec=%s", + _ver, _rfl_resp.status_code, + _rfl_j.get("out_ids_len"), + _rfl_j.get("sum_lp"), + _rfl_j.get("mid_lp"), + ("%.4f" % _rfl_rate) if _rfl_rate is not None + else "NA", + _rfl_spec, + ) + except Exception as _e_rfl44: + try: + self.logger.info( + "[RepeatFixedLongProbe-v44] failure: %r", + _e_rfl44, + ) + except Exception: + pass # [v41] server-info probe try: import requests as _rq_si41 From 7e39669f8d16a78ed294b8033524f455625389ba Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 21:01:08 +0800 Subject: [PATCH 113/140] feat(engine): verify --- areal/engine/megatron_engine.py | 128 ++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 3b1843482a..405cc18994 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -202,6 +202,7 @@ def __init__(self, config: TrainEngineConfig): "v28:MTPSigmaDeltaBf16(AREAL_MTP_SIGMA_DELTA_BF16)", "v43:FixedLongProbe+MTPWeightHashDelta+CrossProcFix(AREAL_MTP_V30_DIAG)", "v44:MTPSrcHash+RepeatFixedLongProbe(AREAL_MTP_V30_DIAG)", + "v45:MTPULPGap+DraftIPCStall(AREAL_MTP_V30_DIAG)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -3533,6 +3534,45 @@ def _serialize_mtp_tensors_for_update( _v42_ver, ) self._v43_prev_digests = _cur_map_v43 + # [MTPDraftIPCStall-v45] cumulative stall count + # per tensor. If a hash equals the previous + # version's hash, the draft worker saw a + # bit-exact copy: stall_count += 1, else reset. + try: + if not hasattr(self, "_v45_stall_count"): + self._v45_stall_count = {} + _prev_cur = getattr( + self, "_v45_last_cur_map", None + ) + _v45_rows = [] + for _n_s, _h_s in _cur_map_v43.items(): + if (_prev_cur is not None + and _prev_cur.get(_n_s) == _h_s): + self._v45_stall_count[_n_s] = ( + self._v45_stall_count.get(_n_s, 0) + 1 + ) + else: + self._v45_stall_count[_n_s] = 0 + _v45_rows.append( + (_n_s, self._v45_stall_count[_n_s]) + ) + self._v45_last_cur_map = dict(_cur_map_v43) + _v45_rows.sort(key=lambda r: -r[1]) + self.logger.info( + "[MTPDraftIPCStall-v45] version=%s " + "max_stall=%s top5_stalled=%s", + _v42_ver, + (_v45_rows[0][1] if _v45_rows else None), + _v45_rows[:5], + ) + except Exception as _e_v45_s: + try: + self.logger.info( + "[MTPDraftIPCStall-v45] failure: %r", + _e_v45_s, + ) + except Exception: + pass except Exception as _e_delta_v43: try: self.logger.info( @@ -4337,6 +4377,94 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) except Exception: pass + # [MTPULPGap-v45] Quantify the bf16-ULP gap on + # the Megatron-side fp32 master. For each MTP + # tensor, we round fp32->bf16 and compare to + # the previous sync's rounded bf16. If no + # element flipped even one bf16 ULP, the + # downstream draft sees a bit-exact copy of + # the PREVIOUS version, regardless of what + # fp32 master has been doing. This nails the + # "hidden_layernorm.weight is frozen" obs + # (log.27/28) as a pure quantization ceiling. + try: + import torch as _torch_v45 + _v45_ver = None + try: + _v45_ver = int(self.get_version()) + except Exception: + _v45_ver = None + if not hasattr(self, "_v45_prev_bf16"): + self._v45_prev_bf16 = {} + if not hasattr(self, "_v45_prev_fp32"): + self._v45_prev_fp32 = {} + _v45_t_fp32 = _mtp_param.detach().float() + _v45_bf16 = _v45_t_fp32.to(_torch_v45.bfloat16) + _v45_prev_b = self._v45_prev_bf16.get(name) + _v45_prev_f = self._v45_prev_fp32.get(name) + if (_v45_prev_b is not None + and _v45_prev_b.shape == _v45_bf16.shape): + _v45_flips = int( + (_v45_bf16 != _v45_prev_b).sum().item() + ) + else: + _v45_flips = -1 + if (_v45_prev_f is not None + and _v45_prev_f.shape == _v45_t_fp32.shape): + _v45_d = (_v45_t_fp32 - _v45_prev_f).abs() + _v45_drift_max = float(_v45_d.max().item()) + _v45_drift_mean = float(_v45_d.mean().item()) + else: + _v45_drift_max = -1.0 + _v45_drift_mean = -1.0 + # bf16 ULP estimator for the tensor's + # dominant magnitude: ULP = 2^(e-7) where + # 2^e <= |x|max < 2^(e+1). For |x|max=0 + # (zero tensor) default 2^-133 (denormal). + _v45_amax = float( + _v45_t_fp32.abs().max().item() + ) + if _v45_amax > 0: + import math as _m_v45 + _v45_e = _m_v45.floor( + _m_v45.log2(_v45_amax) + ) + _v45_ulp_max = 2.0 ** (_v45_e - 7) + else: + _v45_ulp_max = float('nan') + # Estimated syncs until the next ULP flip + # on the largest-magnitude element: ULP / + # per-element drift. + if (_v45_drift_max > 0 + and _v45_ulp_max == _v45_ulp_max): + _v45_eta = _v45_ulp_max / _v45_drift_max + else: + _v45_eta = -1.0 + self.logger.info( + "[MTPULPGap-v45] version=%s name=%s " + "shape=%s amax=%.6e bf16_ulp_at_amax=%.6e " + "drift_abs_max=%.6e drift_abs_mean=%.6e " + "bf16_flips_vs_prev=%s " + "eta_syncs_to_next_flip=%.2f", + _v45_ver, name, tuple(_v45_t_fp32.shape), + _v45_amax, _v45_ulp_max, + _v45_drift_max, _v45_drift_mean, + _v45_flips, _v45_eta, + ) + # keep one-version history + self._v45_prev_bf16[name] = ( + _v45_bf16.detach().clone() + ) + self._v45_prev_fp32[name] = ( + _v45_t_fp32.detach().clone() + ) + except Exception as _e_v45: + try: + self.logger.info( + "[MTPULPGap-v45] failure: %r", _e_v45, + ) + except Exception: + pass mtp_hf_tensors.extend( convert_to_hf( self.tf_config, From b862d68ffe7a4a2c2be2c31fe30d76e4cf15130d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 21:47:07 +0800 Subject: [PATCH 114/140] feat(megatron): trick --- areal/engine/megatron_engine.py | 202 ++++++++++++++++++++ examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 2 files changed, 203 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 405cc18994..937cf08090 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -203,6 +203,7 @@ def __init__(self, config: TrainEngineConfig): "v43:FixedLongProbe+MTPWeightHashDelta+CrossProcFix(AREAL_MTP_V30_DIAG)", "v44:MTPSrcHash+RepeatFixedLongProbe(AREAL_MTP_V30_DIAG)", "v45:MTPULPGap+DraftIPCStall(AREAL_MTP_V30_DIAG)", + "v46:ForceTickBf16+ShipFlips(AREAL_MTP_V46_FORCE_TICK)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -4596,6 +4597,143 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: else: _u = _tn_sd _had_prev = False + # [MTPForceTickBf16-v46] Cap draft-IPC + # stall at K syncs by promoting residual + # to ±ULP/2 when bf16 has not flipped in + # K_force consecutive syncs. Preserves + # long-run unbiasedness: the ±ULP/2 + # injection is a Σ-Δ quantum that the + # next sync's residual cancels. + try: + _ft_on_v46 = ( + _os_v16.environ.get( + 'AREAL_MTP_V46_FORCE_TICK', + '1', + ) == '1' + ) + except Exception: + _ft_on_v46 = True + if _ft_on_v46: + try: + _ft_k_v46 = int( + _os_v16.environ.get( + 'AREAL_MTP_V46_FORCE_TICK_K', + '8', + ) + ) + except Exception: + _ft_k_v46 = 8 + try: + _ft_ratio_v46 = float( + _os_v16.environ.get( + 'AREAL_MTP_V46_FORCE_TICK_RATIO', + '0.10', + ) + ) + except Exception: + _ft_ratio_v46 = 0.10 + if not hasattr( + self, '_mtp_v46_stale' + ): + self._mtp_v46_stale = {} + if not hasattr( + self, '_mtp_v46_prev_ship' + ): + self._mtp_v46_prev_ship = {} + _ft_amax = float( + _u.abs().max().item() + ) + if _ft_amax > 0: + import math as _m_v46 + _ft_e = _m_v46.floor( + _m_v46.log2(_ft_amax) + ) + _ft_ulp = 2.0 ** (_ft_e - 7) + else: + _ft_ulp = 0.0 + _ft_stale = ( + self._mtp_v46_stale.get(_nm_sd, 0) + ) + _ft_resid_absmax = 0.0 + if ( + _r_prev is not None + and _r_prev.shape == _tn_sd.shape + ): + try: + _ft_resid_absmax = float( + _r_prev.abs().max().item() + ) + except Exception: + _ft_resid_absmax = 0.0 + _ft_trigger_stale = ( + _ft_stale >= _ft_k_v46 + ) + _ft_trigger_ratio = ( + _ft_ulp > 0 + and _ft_ratio_v46 > 0 + and _ft_resid_absmax + >= _ft_ratio_v46 * _ft_ulp + ) + _ft_fired = False + if ( + _ft_trigger_stale + and _ft_ulp > 0 + ): + # Promote _u by sign(resid or + # drift) * ULP/2 on the single + # element with largest |resid| + # so that RNE flips exactly one + # bf16 bucket. Minimal, unbiased + # on average (residual carries + # opposite sign next sync). + try: + _ft_flat = _u.view(-1) + if _r_prev is not None: + _ft_signmap = ( + _r_prev.view(-1) + ) + else: + _ft_signmap = _ft_flat + _ft_sign = ( + _torch_v16.sign(_ft_signmap) + ) + _ft_sign = _torch_v16.where( + _ft_sign == 0, + _torch_v16.ones_like( + _ft_sign + ), + _ft_sign, + ) + _u = ( + _u + + _ft_sign.view_as(_u) + * (0.5 * _ft_ulp) + ) + _ft_fired = True + except Exception: + _ft_fired = False + self._mtp_v46_stale[_nm_sd] = ( + 0 if _ft_fired else _ft_stale + ) + # store diag for post-loop log + if not hasattr( + self, '_mtp_v46_fire_log' + ): + self._mtp_v46_fire_log = [] + if ( + _ft_fired + or _ft_trigger_ratio + or _ft_trigger_stale + ): + self._mtp_v46_fire_log.append( + ( + _nm_sd, + _ft_stale, + _ft_resid_absmax, + _ft_ulp, + _ft_fired, + ) + ) # RNE fp32 -> bf16 and retrieve actual # quantized fp32 value for residual calc. _bf16 = _u.to(_torch_v16.bfloat16) @@ -4625,6 +4763,70 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: mtp_hf_tensors[_i] = ( _nm_sd, _bf16.contiguous(), ) + # [MTPShipFlips-v46] update stale + # counter: if shipped bf16 payload + # matches previous version's shipped + # payload bit-for-bit, stale += 1. + try: + _ship_prev_v46 = ( + self._mtp_v46_prev_ship.get( + _nm_sd + ) + if hasattr( + self, '_mtp_v46_prev_ship' + ) + else None + ) + _ship_flips_v46 = -1 + if ( + _ship_prev_v46 is not None + and _ship_prev_v46.shape + == _bf16.shape + ): + _ship_flips_v46 = int( + ( + _bf16 != _ship_prev_v46 + ).sum().item() + ) + if _ship_flips_v46 == 0: + self._mtp_v46_stale[_nm_sd] = ( + self._mtp_v46_stale.get( + _nm_sd, 0 + ) + 1 + ) + else: + self._mtp_v46_stale[_nm_sd] = 0 + self._mtp_v46_prev_ship[_nm_sd] = ( + _bf16.detach().clone() + ) + # log per-tensor only if it fired + # or was previously stale. + if ( + _ft_fired + or _ship_flips_v46 == 0 + ): + self.logger.info( + '[MTPShipFlips-v46] ' + 'name=%s ship_flips=%s ' + 'stale=%s force_fired=%s ' + 'ulp=%.3e resid_absmax=%.3e', + _nm_sd, _ship_flips_v46, + self._mtp_v46_stale.get( + _nm_sd, 0 + ), + _ft_fired, + _ft_ulp if _ft_on_v46 else 0.0, + _ft_resid_absmax if _ft_on_v46 else 0.0, + ) + except Exception as _e_sf_v46: + try: + self.logger.info( + '[MTPShipFlips-v46] ' + 'failure name=%s err=%r', + _nm_sd, _e_sf_v46, + ) + except Exception: + pass _sd_applied += 1 if _shift_cnt > 0: _sd_total_shifted += _shift_cnt diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 6b09adf474..95a1fbbd92 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -82,7 +82,7 @@ actor: # MTP Online Training - keeps EAGLE draft heads aligned with evolving policy enable_mtp_training: true mtp_num_layers: 1 - mtp_loss_scaling_factor: 0.2 + mtp_loss_scaling_factor: 5 scheduling_spec: - task_type: worker From 6a98f7feaf6cfd0cdbf25f41dc6d6f625dd4885b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 22:11:58 +0800 Subject: [PATCH 115/140] fix: mtp_loss_scaling_factor --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 95a1fbbd92..57f7f20ed1 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -82,7 +82,7 @@ actor: # MTP Online Training - keeps EAGLE draft heads aligned with evolving policy enable_mtp_training: true mtp_num_layers: 1 - mtp_loss_scaling_factor: 5 + mtp_loss_scaling_factor: 1 scheduling_spec: - task_type: worker From 070c1a5f7fcc37c5da0b852ebe9f9b3313fdebb0 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 4 May 2026 23:22:36 +0800 Subject: [PATCH 116/140] feat(megatron_engine): bf16 --- areal/engine/megatron_engine.py | 202 ++++++++++++++++++++ examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 2 files changed, 203 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 937cf08090..cbc42f8db2 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -204,6 +204,7 @@ def __init__(self, config: TrainEngineConfig): "v44:MTPSrcHash+RepeatFixedLongProbe(AREAL_MTP_V30_DIAG)", "v45:MTPULPGap+DraftIPCStall(AREAL_MTP_V30_DIAG)", "v46:ForceTickBf16+ShipFlips(AREAL_MTP_V46_FORCE_TICK)", + "v47:MTPMasterAmp(AREAL_MTP_V47_MASTER_AMP)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -1222,6 +1223,50 @@ def optimizer_step(self): "[SpecDecDiag-v20 D10] snapshot failed: %s", _e_d10, ) + # [MTPMasterAmp-v47] pre-step snapshot of MTP fp32 master. + # Captured before optimizer.step() so we can compute the + # raw Adam delta after the step and amplify it to bf16 ULP + # when needed. Gate: AREAL_MTP_V47_MASTER_AMP (default 1). + try: + import os as _os_v47pre + _v47_on = ( + _os_v47pre.environ.get( + 'AREAL_MTP_V47_MASTER_AMP', '1' + ) == '1' + ) + except Exception: + _v47_on = True + self._v47_pre_master = {} + self._v47_pre_data = {} + self._v47_on_step = bool( + _v47_on and getattr(self, 'enable_mtp_training', False) + ) + if self._v47_on_step and getattr(self, 'model', None) is not None: + try: + for _mod_v47 in self.model: + for _n_v47, _p_v47 in _mod_v47.named_parameters(): + if ('.mtp.' not in _n_v47 + and '.mtp_layers.' not in _n_v47): + continue + _mp_v47 = getattr(_p_v47, 'main_param', None) + if _mp_v47 is not None: + try: + self._v47_pre_master[_n_v47] = ( + _mp_v47.detach().clone() + ) + except Exception: + pass + try: + self._v47_pre_data[_n_v47] = ( + _p_v47.data.detach().clone() + ) + except Exception: + pass + except Exception as _e_v47pre: + self.logger.warning( + '[MTPMasterAmp-v47] pre-snapshot failed: %s', + _e_v47pre, + ) # [MTPGradProbe-v26] Install post-accumulate-grad hook on MTP # params (once) so grads are captured at the moment they land, # BEFORE Megatron's DistributedOptimizer consumes and frees them. @@ -1351,6 +1396,163 @@ def _hook(_p): self.logger.warning("[MTPGradProbe-v25] probe error: %s", _e) with trace_scope("megatron_engine.step"): update_successful, grad_norm, _ = self.optimizer.step() + # [MTPMasterAmp-v47] post-step delta amplification. + # For each MTP fp32 master tensor whose Adam step delta + # amax is smaller than beta * bf16_ULP, rescale the + # delta (preserving per-element sign / ratio) so that the + # shipped bf16 payload flips at least beta of a bucket + # each step. This breaks the bf16 ULP trap at compute + # time without distorting Adam's direction (we never + # touch the optimizer's internal m / v). + if ( + bool(update_successful) + and getattr(self, '_v47_on_step', False) + and getattr(self, 'model', None) is not None + ): + try: + import math as _m_v47 + import os as _os_v47p + import torch as _torch_v47p + try: + _beta_v47 = float( + _os_v47p.environ.get( + 'AREAL_MTP_V47_AMP_BETA', '0.5' + ) + ) + except Exception: + _beta_v47 = 0.5 + try: + _min_ratio_v47 = float( + _os_v47p.environ.get( + 'AREAL_MTP_V47_AMP_MIN_RATIO', '0.5' + ) + ) + except Exception: + _min_ratio_v47 = 0.5 + _n_amp_v47 = 0 + _n_skip_v47 = 0 + _scales_v47 = [] + for _mod_v47p in self.model: + for _n_v47p, _p_v47p in ( + _mod_v47p.named_parameters() + ): + if ('.mtp.' not in _n_v47p + and '.mtp_layers.' not in _n_v47p): + continue + _mp_v47p = getattr( + _p_v47p, 'main_param', None + ) + if _mp_v47p is None: + _n_skip_v47 += 1 + continue + _pre_v47 = self._v47_pre_master.get(_n_v47p) + if ( + _pre_v47 is None + or _pre_v47.shape != _mp_v47p.shape + ): + _n_skip_v47 += 1 + continue + _delta_v47 = _mp_v47p.data - _pre_v47 + _raw_dmax_v47 = float( + _delta_v47.abs().max().item() + ) + _amax_v47 = float( + _mp_v47p.data.abs().max().item() + ) + if _amax_v47 <= 0.0 or _raw_dmax_v47 <= 0.0: + _n_skip_v47 += 1 + continue + _e_v47 = _m_v47.floor(_m_v47.log2(_amax_v47)) + _ulp_v47 = 2.0 ** (_e_v47 - 7) + _target_v47 = _beta_v47 * _ulp_v47 + _ratio_v47 = _raw_dmax_v47 / _ulp_v47 + if _ratio_v47 >= _min_ratio_v47: + _n_skip_v47 += 1 + _log_this = False + _scale_v47 = 1.0 + _clipped = False + else: + _scale_v47 = ( + _target_v47 / _raw_dmax_v47 + ) + # cap to avoid runaway if Adam step is + # spuriously tiny (e.g. right after + # warmup) — hard ceiling 1e6. + _clipped = False + if _scale_v47 > 1.0e6: + _scale_v47 = 1.0e6 + _clipped = True + # write amplified delta back to fp32 + # master, leaving optimizer internals + # (m, v) unchanged. + _new_master_v47 = ( + _pre_v47 + _scale_v47 * _delta_v47 + ) + _mp_v47p.data.copy_(_new_master_v47) + # propagate to the bf16 model param so + # any downstream read path (including + # convert_to_hf) sees the new weight + # right now. + try: + _p_v47p.data.copy_( + _mp_v47p.data.to( + _p_v47p.data.dtype + ) + ) + except Exception: + pass + _n_amp_v47 += 1 + _scales_v47.append(_scale_v47) + _log_this = True + _amp_dmax_v47 = float( + ( + _mp_v47p.data - _pre_v47 + ).abs().max().item() + ) + if _log_this: + self.logger.info( + '[MTPMasterAmp-v47] name=%s ' + 'pre_amax=%.6e post_amax=%.6e ' + 'raw_dmax=%.3e amp_dmax=%.3e ' + 'ulp=%.3e beta=%.3f ' + 'scale=%.3e clipped=%s', + _n_v47p, + float(_pre_v47.abs().max().item()), + _amax_v47, + _raw_dmax_v47, _amp_dmax_v47, + _ulp_v47, _beta_v47, + _scale_v47, _clipped, + ) + # summary + if _scales_v47: + try: + import statistics as _st_v47 + _geo = _m_v47.exp( + _st_v47.fmean( + [_m_v47.log(s) for s in _scales_v47] + ) + ) + except Exception: + _geo = float('nan') + else: + _geo = float('nan') + self.logger.info( + '[MTPMasterAmpSummary-v47] ' + 'n_amplified=%d n_skipped=%d ' + 'geomean_scale=%s beta=%.3f ' + 'min_ratio=%.3f', + _n_amp_v47, _n_skip_v47, str(_geo), + _beta_v47, _min_ratio_v47, + ) + except Exception as _e_v47_post: + self.logger.warning( + '[MTPMasterAmp-v47] post-step failed: %s', + _e_v47_post, + ) + finally: + # release snapshots — memory-bounded. + self._v47_pre_master = {} + self._v47_pre_data = {} # [MTPPostOptim-v25] Diagnostic post-optimizer-step probe. try: _v25_post_seen = set() diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 57f7f20ed1..551944f03c 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -52,7 +52,7 @@ actor: max_tokens_per_mb: 10240 optimizer: type: adam - lr: 3e-6 + lr: 1e-5 weight_decay: 0.003 beta1: 0.9 beta2: 0.999 From 86c1e898edcacd509c0c2ec0a149d2ba5a16a415 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 5 May 2026 09:37:56 +0800 Subject: [PATCH 117/140] feat(megatron_engine): fix --- areal/engine/megatron_engine.py | 158 +++++++++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 2 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index cbc42f8db2..851a798e87 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -205,6 +205,7 @@ def __init__(self, config: TrainEngineConfig): "v45:MTPULPGap+DraftIPCStall(AREAL_MTP_V30_DIAG)", "v46:ForceTickBf16+ShipFlips(AREAL_MTP_V46_FORCE_TICK)", "v47:MTPMasterAmp(AREAL_MTP_V47_MASTER_AMP)", + "v48:MTPMasterCarry(AREAL_MTP_V48_MASTER_CARRY)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -1223,6 +1224,46 @@ def optimizer_step(self): "[SpecDecDiag-v20 D10] snapshot failed: %s", _e_d10, ) + # [MTPMasterCarry-v48] pre-step snapshot of MTP fp32 master. + # Mirrors v47 snapshot but is used by a residual-carry post-step + # that NEVER scales the Adam delta (see block below). Gate: + # AREAL_MTP_V48_MASTER_CARRY (default '1'). + try: + import os as _os_v48pre + _v48_on = ( + _os_v48pre.environ.get( + 'AREAL_MTP_V48_MASTER_CARRY', '1' + ) == '1' + ) + except Exception: + _v48_on = True + self._v48_pre_master = {} + self._v48_on_step = bool( + _v48_on and getattr(self, 'enable_mtp_training', False) + ) + if ( + self._v48_on_step + and getattr(self, 'model', None) is not None + ): + try: + for _mod_v48 in self.model: + for _n_v48, _p_v48 in _mod_v48.named_parameters(): + if ('.mtp.' not in _n_v48 + and '.mtp_layers.' not in _n_v48): + continue + _mp_v48 = getattr(_p_v48, 'main_param', None) + if _mp_v48 is not None: + try: + self._v48_pre_master[_n_v48] = ( + _mp_v48.detach().clone() + ) + except Exception: + pass + except Exception as _e_v48pre: + self.logger.warning( + '[MTPMasterCarry-v48] pre-snapshot failed: %s', + _e_v48pre, + ) # [MTPMasterAmp-v47] pre-step snapshot of MTP fp32 master. # Captured before optimizer.step() so we can compute the # raw Adam delta after the step and amplify it to bf16 ULP @@ -1231,11 +1272,11 @@ def optimizer_step(self): import os as _os_v47pre _v47_on = ( _os_v47pre.environ.get( - 'AREAL_MTP_V47_MASTER_AMP', '1' + 'AREAL_MTP_V47_MASTER_AMP', '0' ) == '1' ) except Exception: - _v47_on = True + _v47_on = False self._v47_pre_master = {} self._v47_pre_data = {} self._v47_on_step = bool( @@ -1553,6 +1594,119 @@ def _hook(_p): # release snapshots — memory-bounded. self._v47_pre_master = {} self._v47_pre_data = {} + # [MTPMasterCarry-v48] master-side Sigma-Delta residual carry. + # This is the v48 replacement for v47 (which scaled the whole delta + # by a tensor-wise scalar and destroyed model alignment in log.31). + # Here we NEVER touch the magnitude/direction of the Adam delta. + # Instead we maintain per-parameter fp32 residual and only flip the + # bf16 bucket for the elements whose accumulated residual exceeds + # +/- ULP/2, exactly like the ship-side v28 Σ-Δ but on the compute + # (master) side where it actually matters. + if ( + bool(update_successful) + and getattr(self, '_v48_on_step', False) + and getattr(self, 'model', None) is not None + ): + try: + import torch as _torch_v48 + if not hasattr(self, '_v48_residual'): + self._v48_residual = {} + _n_flipped_v48 = 0 + _n_seen_v48 = 0 + _max_res_ratio_v48 = 0.0 + _max_res_name_v48 = '' + for _mod_v48p in self.model: + for _n_v48p, _p_v48p in ( + _mod_v48p.named_parameters() + ): + if ('.mtp.' not in _n_v48p + and '.mtp_layers.' not in _n_v48p): + continue + _mp_v48p = getattr( + _p_v48p, 'main_param', None + ) + if _mp_v48p is None: + continue + _n_seen_v48 += 1 + # residual is fp32, same shape as main_param. + _res = self._v48_residual.get(_n_v48p) + if _res is None or _res.shape != _mp_v48p.shape: + _res = _torch_v48.zeros_like( + _mp_v48p.data, + dtype=_torch_v48.float32, + ) + self._v48_residual[_n_v48p] = _res + # accumulate: want = fp32_master + residual + _fp32_new = _mp_v48p.data.to(_torch_v48.float32) + _want = _fp32_new + _res + _bf16_dtype = _p_v48p.data.dtype + _bf16_new = _want.to(_bf16_dtype) + # new residual captures quantization loss (fp32 level) + _new_res = _want - _bf16_new.to(_torch_v48.float32) + # count how many bf16 elements flip relative to + # "no-carry" rounding of fp32_new alone + try: + _bf16_baseline = _fp32_new.to(_bf16_dtype) + _n_flip_this = int( + ( + _bf16_new.to(_torch_v48.float32) + != _bf16_baseline.to(_torch_v48.float32) + ).sum().item() + ) + except Exception: + _n_flip_this = -1 + # write back: master stays fp32-accurate; bf16 is + # quantized-with-accumulated-residual. + _mp_v48p.data.copy_(_want.to(_mp_v48p.dtype)) + try: + _p_v48p.data.copy_(_bf16_new) + except Exception: + pass + self._v48_residual[_n_v48p] = _new_res + # record residual magnitude ratio vs ULP + try: + import math as _m_v48ip + _amax = float( + _mp_v48p.data.abs().max().item() + ) + if _amax > 0.0: + _e = _m_v48ip.floor(_m_v48ip.log2(_amax)) + _ulp = 2.0 ** (_e - 7) + _rmax = float( + _new_res.abs().max().item() + ) + _ratio = _rmax / max(_ulp, 1e-30) + if _ratio > _max_res_ratio_v48: + _max_res_ratio_v48 = _ratio + _max_res_name_v48 = _n_v48p + # per-tensor log, cheap (O(#mtp params)) + self.logger.info( + '[MTPMasterCarry-v48] name=%s ' + 'amax=%.3e ulp=%.3e ' + 'res_amax=%.3e res_ratio=%.3f ' + 'flips=%d', + _n_v48p, _amax, _ulp, _rmax, + _ratio, _n_flip_this, + ) + except Exception: + pass + if _n_flip_this > 0: + _n_flipped_v48 += 1 + self.logger.info( + '[MTPMasterCarrySummary-v48] ' + 'n_seen=%d n_flipped_any=%d ' + 'max_res_ratio=%.3f max_res_name=%s', + _n_seen_v48, _n_flipped_v48, + _max_res_ratio_v48, _max_res_name_v48, + ) + except Exception as _e_v48_post: + self.logger.warning( + '[MTPMasterCarry-v48] post-step failed: %s', + _e_v48_post, + ) + finally: + # release pre-snapshot to bound memory + self._v48_pre_master = {} # [MTPPostOptim-v25] Diagnostic post-optimizer-step probe. try: _v25_post_seen = set() From 4f0d428e7bb6d4db99a188a15eb3c1cc00bd1a41 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 5 May 2026 12:47:46 +0800 Subject: [PATCH 118/140] feat(megatron): mtp native --- areal/engine/megatron_engine.py | 71 ++++++++++++++++++++- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 + 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 851a798e87..8cd16395e3 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -206,6 +206,8 @@ def __init__(self, config: TrainEngineConfig): "v46:ForceTickBf16+ShipFlips(AREAL_MTP_V46_FORCE_TICK)", "v47:MTPMasterAmp(AREAL_MTP_V47_MASTER_AMP)", "v48:MTPMasterCarry(AREAL_MTP_V48_MASTER_CARRY)", + "v49:MTPLossClipTight+GradFp32Coerce+LossBoost(AREAL_MTP_V49_CLIP_TIGHT,AREAL_MTP_V49_GRAD_FP32_COERCE,AREAL_MTP_V49_LOSS_BOOST)", + "v50:MTPNativePassthrough(default-on; set AREAL_MTP_NATIVE_AUTOSCALER=0 to fall back to legacy FIFO)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -218,7 +220,7 @@ def __init__(self, config: TrainEngineConfig): "AREAL_MTP_SIGMA_DELTA_BF16", "1"), "AREAL_MTP_NATIVE_AUTOSCALER": _os_banner.environ.get( - "AREAL_MTP_NATIVE_AUTOSCALER", "0"), + "AREAL_MTP_NATIVE_AUTOSCALER", "1"), "AREAL_MTP_V30_DIAG": _os_banner.environ.get( "AREAL_MTP_V30_DIAG", "1"), @@ -1435,6 +1437,59 @@ def _hook(_p): ) except Exception as _e: self.logger.warning("[MTPGradProbe-v25] probe error: %s", _e) + # [MTPGradFp32Coerce-v50] Belt-and-suspenders fp32 upcast of MTP main_grad + # before optimizer.step. Passthrough (v50) aligns scale with slime/verl + # but bf16 grad accumulation bucket still truncates small updates across + # ~54 microbatches. Slime mitigates this with --accumulate-allreduce-grads + # -in-fp32; here we do the runtime equivalent on MTP params only. + # Gate: AREAL_MTP_V50_GRAD_FP32_COERCE (default "1"). + try: + import os as _os_v50g + _v50_gfp32 = ( + _os_v50g.environ.get('AREAL_MTP_V50_GRAD_FP32_COERCE', '1') == '1' + ) + except Exception: + _v50_gfp32 = True + if ( + _v50_gfp32 + and getattr(self, 'enable_mtp_training', False) + and getattr(self, 'model', None) is not None + ): + try: + import torch as _torch_v50g + _v50g_n = 0 + _v50g_amax = 0.0 + _v50g_name = '' + for _mod_v50g in self.model: + for _n_v50g, _p_v50g in _mod_v50g.named_parameters(): + if ('.mtp.' not in _n_v50g + and '.mtp_layers.' not in _n_v50g): + continue + _mg_v50g = getattr(_p_v50g, 'main_grad', None) + if _mg_v50g is None: + continue + if _mg_v50g.dtype == _torch_v50g.float32: + continue + try: + _fp32 = _mg_v50g.to(_torch_v50g.float32) + _p_v50g.main_grad = _fp32 + _v50g_n += 1 + _a = float(_fp32.abs().max().item()) + if _a > _v50g_amax: + _v50g_amax = _a + _v50g_name = _n_v50g + except Exception: + pass + if _v50g_n > 0: + self.logger.info( + '[MTPGradFp32Coerce-v50] coerced=%d ' + 'max_grad_amax=%.3e max_name=%s', + _v50g_n, _v50g_amax, _v50g_name, + ) + except Exception as _e_v50g: + self.logger.warning( + '[MTPGradFp32Coerce-v50] failed: %s', _e_v50g, + ) with trace_scope("megatron_engine.step"): update_successful, grad_norm, _ = self.optimizer.step() # [MTPMasterAmp-v47] post-step delta amplification. @@ -2247,16 +2302,26 @@ def _patched_postprocess( # Gated so the legacy behaviour remains # bit-exact by default. Enable with # AREAL_MTP_NATIVE_AUTOSCALER=1 + # [v50:MTPNativePassthrough] default-on. + # Passthrough via MTPLossAutoScaler.apply is the + # verl/slime-aligned path and in Megatron-Core 0.16.0 + # it is the ONLY numerically correct path: schedules.py + # sets main_loss_backward_scale = loss_scale / + # num_microbatches automatically after every + # forward_step, so the FIFO + DoubleScale inverse + # mechanism is strictly redundant and introduces + # bf16 rounding jitter. Set AREAL_MTP_NATIVE_AUTOSCALER=0 + # to fall back to legacy FIFO (diagnostic only). try: import os as _os_v17 _v17_on = ( _os_v17.environ.get( "AREAL_MTP_NATIVE_AUTOSCALER", - "0", + "1", ) == "1" ) except Exception: - _v17_on = False + _v17_on = True if _v17_on: try: from megatron.core.transformer.multi_token_prediction import ( diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 551944f03c..bb4640a88c 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -95,6 +95,8 @@ actor: megatron: mtp_num_layers: 1 mtp_loss_scaling_factor: 0.2 + ddp: + grad_reduce_in_fp32: true ref: backend: ${actor.backend} From 7eda213f5a8f255f0d28f55f105dcfc81ecae839 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 5 May 2026 14:25:14 +0800 Subject: [PATCH 119/140] feat: MTP L2 norm --- areal/engine/megatron_engine.py | 68 +++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 8cd16395e3..7de8545d0a 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -208,6 +208,7 @@ def __init__(self, config: TrainEngineConfig): "v48:MTPMasterCarry(AREAL_MTP_V48_MASTER_CARRY)", "v49:MTPLossClipTight+GradFp32Coerce+LossBoost(AREAL_MTP_V49_CLIP_TIGHT,AREAL_MTP_V49_GRAD_FP32_COERCE,AREAL_MTP_V49_LOSS_BOOST)", "v50:MTPNativePassthrough(default-on; set AREAL_MTP_NATIVE_AUTOSCALER=0 to fall back to legacy FIFO)", + "v51:MTPGradClipNorm(default-on; AREAL_MTP_V51_GRAD_CLIP_NORM=, default=1.0)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -1490,6 +1491,73 @@ def _hook(_p): self.logger.warning( '[MTPGradFp32Coerce-v50] failed: %s', _e_v50g, ) + # [MTPGradClipNorm-v51] Per-component gradient L2-norm clipping for + # MTP parameters only, applied AFTER fp32 coerce and BEFORE + # optimizer.step(). Megatron-Core's `gradient_clipping=1.0` is a + # GLOBAL (backbone+MTP joint) norm, which lets MTP grad dominate + # when backbone grad is small (KL-regularised RL). slime mitigates + # this via `check_mtp_loss(max=1.0)` + `accumulate-allreduce-grads- + # in-fp32`; the latter is now on (YAML grad_reduce_in_fp32=true) + # but log.33 still shows per-step max|delta|=0.59-0.64 at v9-v13, + # correlated with PAW crashes v10=0.005 / v14=0.008. v51 adds the + # missing MTP-only norm clip so big spikes through MTPLossAutoScaler + # cannot push the draft head into a divergent region. + # Threshold: AREAL_MTP_V51_GRAD_CLIP_NORM (default 1.0). + # Disable: AREAL_MTP_V51_GRAD_CLIP_NORM=0 + try: + import os as _os_v51c + _v51_thr = float(_os_v51c.environ.get( + 'AREAL_MTP_V51_GRAD_CLIP_NORM', '1.0')) + except Exception: + _v51_thr = 1.0 + if ( + _v51_thr > 0.0 + and getattr(self, 'enable_mtp_training', False) + and getattr(self, 'model', None) is not None + ): + try: + import torch as _torch_v51c + _v51_grads = [] + _v51_names = [] + for _mod_v51c in self.model: + for _n_v51c, _p_v51c in _mod_v51c.named_parameters(): + if ('.mtp.' not in _n_v51c + and '.mtp_layers.' not in _n_v51c): + continue + _mg_v51c = getattr(_p_v51c, 'main_grad', None) + if _mg_v51c is None: + continue + _v51_grads.append(_mg_v51c) + _v51_names.append(_n_v51c) + if _v51_grads: + _v51_total_sq = _torch_v51c.zeros( + (), dtype=_torch_v51c.float32, + device=_v51_grads[0].device) + for _g in _v51_grads: + _v51_total_sq = _v51_total_sq + ( + _g.detach().float().pow(2).sum()) + _v51_norm = float(_v51_total_sq.sqrt().item()) + _v51_clipped = False + _v51_scale = 1.0 + if _v51_norm > _v51_thr and _v51_norm > 0.0: + _v51_scale = _v51_thr / (_v51_norm + 1e-12) + for _g in _v51_grads: + _g.mul_(_v51_scale) + _v51_clipped = True + _gs_v51 = getattr(self, '_global_step', 0) + if (_gs_v51 <= 5 or _gs_v51 % 50 == 0 + or _v51_clipped): + self.logger.info( + '[MTPGradClipNorm-v51] step=%d n_params=%d ' + 'mtp_grad_norm=%.4e threshold=%.4e ' + 'clipped=%s scale=%.4e', + _gs_v51, len(_v51_grads), _v51_norm, + _v51_thr, _v51_clipped, _v51_scale, + ) + except Exception as _e_v51c: + self.logger.warning( + '[MTPGradClipNorm-v51] failed: %s', _e_v51c, + ) with trace_scope("megatron_engine.step"): update_successful, grad_norm, _ = self.optimizer.step() # [MTPMasterAmp-v47] post-step delta amplification. From 2fd3890b190082b75405e7e9bde82844058dc221 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 5 May 2026 15:40:05 +0800 Subject: [PATCH 120/140] feat(engine): MTPSourceLossCap --- areal/engine/megatron_engine.py | 62 +++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 7de8545d0a..f20397bff2 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -208,7 +208,8 @@ def __init__(self, config: TrainEngineConfig): "v48:MTPMasterCarry(AREAL_MTP_V48_MASTER_CARRY)", "v49:MTPLossClipTight+GradFp32Coerce+LossBoost(AREAL_MTP_V49_CLIP_TIGHT,AREAL_MTP_V49_GRAD_FP32_COERCE,AREAL_MTP_V49_LOSS_BOOST)", "v50:MTPNativePassthrough(default-on; set AREAL_MTP_NATIVE_AUTOSCALER=0 to fall back to legacy FIFO)", - "v51:MTPGradClipNorm(default-on; AREAL_MTP_V51_GRAD_CLIP_NORM=, default=1.0)", + "v51:MTPGradClipNorm(diag-only; AREAL_MTP_V51_GRAD_CLIP_NORM, default=0 after v52)", + "v52:MTPSourceLossCap(default-on; AREAL_MTP_V52_LOSS_CAP_RATIO=, default=2.0)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -1507,7 +1508,7 @@ def _hook(_p): try: import os as _os_v51c _v51_thr = float(_os_v51c.environ.get( - 'AREAL_MTP_V51_GRAD_CLIP_NORM', '1.0')) + 'AREAL_MTP_V51_GRAD_CLIP_NORM', '0')) except Exception: _v51_thr = 1.0 if ( @@ -2350,6 +2351,63 @@ def _patched_postprocess( _mtp_loss_to_store = mtp_loss_scale * mtp_loss else: _mtp_loss_to_store = mtp_loss_scale * mtp_loss / num_tokens + # [MTPSourceLossCap-v52] Adaptive source-side soft-cap on + # _mtp_loss_to_store BEFORE it is appended to FIFO and BEFORE + # MTPLossAutoScaler.apply. v51 clipped main_grad after backward + # but Adam's m/sqrt(v) normalisation made that ineffective + # (log.34: max|delta|=0.63 at v9 almost unchanged vs log.33=0.64). + # v52 scales the loss ITSELF, which autograd propagates as a + # magnitude reduction on the injected gradient without touching + # direction -- effective for both the FIFO/legacy path and the + # v50 passthrough path. Threshold tracks an EMA of |sum(loss)|; + # default cap = ratio * EMA, ratio via + # AREAL_MTP_V52_LOSS_CAP_RATIO (default 2.0). + # Disable: AREAL_MTP_V52_LOSS_CAP_RATIO=0 + try: + import os as _os_v52s + _v52_ratio = float(_os_v52s.environ.get( + 'AREAL_MTP_V52_LOSS_CAP_RATIO', '2.0')) + except Exception: + _v52_ratio = 2.0 + if _v52_ratio > 0.0: + try: + _v52_abs_sum = float( + _mtp_loss_to_store.detach().float().abs().sum().item() + ) + _v52_ema_prev = getattr( + _engine_ref, '_v52_loss_abs_sum_ema', None) + if _v52_ema_prev is None or _v52_ema_prev <= 0.0: + _v52_ema = _v52_abs_sum + else: + _v52_ema = 0.95 * _v52_ema_prev + 0.05 * _v52_abs_sum + _engine_ref._v52_loss_abs_sum_ema = _v52_ema + _v52_cap = _v52_ratio * _v52_ema + _v52_capped = False + _v52_scale = 1.0 + if (_v52_cap > 0.0 + and _v52_abs_sum > _v52_cap): + _v52_scale = _v52_cap / (_v52_abs_sum + 1e-12) + _mtp_loss_to_store = ( + _mtp_loss_to_store * _v52_scale) + _v52_capped = True + _v52_ctr = getattr( + _engine_ref, '_v52_cap_ctr', 0) + 1 + _engine_ref._v52_cap_ctr = _v52_ctr + if (_v52_ctr <= 5 or _v52_ctr % 50 == 0 + or _v52_capped): + _logger.info( + '[MTPSourceLossCap-v52] call=%d ' + 'abs_sum=%.4e ema=%.4e cap=%.4e ' + 'ratio=%.2f capped=%s scale=%.4e', + _v52_ctr, _v52_abs_sum, _v52_ema, + _v52_cap, _v52_ratio, _v52_capped, + _v52_scale, + ) + except Exception as _e_v52s: + _logger.warning( + '[MTPSourceLossCap-v52] failed: %s', + _e_v52s, + ) _engine_ref._mtp_loss_for_backward.append(_mtp_loss_to_store) # --- BEGIN --- From 57b4e4d3bc887b75165b96cfe966f000a4df3a61 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 5 May 2026 16:58:59 +0800 Subject: [PATCH 121/140] refactor(math): remove config --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index bb4640a88c..1a334a8031 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -93,8 +93,6 @@ actor: env_vars: {} megatron: - mtp_num_layers: 1 - mtp_loss_scaling_factor: 0.2 ddp: grad_reduce_in_fp32: true From fd8fb4a5e70c2000398b8b0ac4a482d77c601e05 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 5 May 2026 17:03:12 +0800 Subject: [PATCH 122/140] fix: fix mtp_loss_scaling_factor --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 1a334a8031..723d73b176 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -82,7 +82,7 @@ actor: # MTP Online Training - keeps EAGLE draft heads aligned with evolving policy enable_mtp_training: true mtp_num_layers: 1 - mtp_loss_scaling_factor: 1 + mtp_loss_scaling_factor: 0.2 scheduling_spec: - task_type: worker From 2f6f69c9b43f5687d96e7b7731f99e8d4994da14 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 5 May 2026 18:11:39 +0800 Subject: [PATCH 123/140] fix: mtp_loss_scaling_factor --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 723d73b176..7a3030deaf 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -82,7 +82,7 @@ actor: # MTP Online Training - keeps EAGLE draft heads aligned with evolving policy enable_mtp_training: true mtp_num_layers: 1 - mtp_loss_scaling_factor: 0.2 + mtp_loss_scaling_factor: 0.05 scheduling_spec: - task_type: worker From fda958bb48437b5c3a09d6c032f28d4e32f34614 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 5 May 2026 20:44:33 +0800 Subject: [PATCH 124/140] fix(engine): remove shared weight --- areal/engine/megatron_engine.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index f20397bff2..b237d5faf1 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -210,6 +210,7 @@ def __init__(self, config: TrainEngineConfig): "v50:MTPNativePassthrough(default-on; set AREAL_MTP_NATIVE_AUTOSCALER=0 to fall back to legacy FIFO)", "v51:MTPGradClipNorm(diag-only; AREAL_MTP_V51_GRAD_CLIP_NORM, default=0 after v52)", "v52:MTPSourceLossCap(default-on; AREAL_MTP_V52_LOSS_CAP_RATIO=, default=2.0)", + "v53:MTPSharedWeightIsolate(detach output_weight for MTP output_layer)", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -2180,6 +2181,29 @@ def _patched_postprocess( self_model.shared_embedding_or_output_weight() ) + # ---------------------------------------------------------------- + # [MTPSharedWeightIsolate-v53] + # When share_embeddings_and_output_weights=True the returned + # `output_weight` IS the shared parameter tensor. If we pass it + # directly to `output_layer(weight=...)` inside the MTP loop + # below, the MTP CE backward will accumulate gradient on that + # shared parameter, contaminating: + # - the embedding lookup used by the main model, and + # - the sglang spec-decoding weight sync (mtp_hf_tensors), + # which empirically drives spec_accept_rate / PAW to collapse + # within ~13 versions (see round 12 log comparison). + # + # Fix: snapshot a *detached* view of the weight specifically + # for the MTP branch. The main-path `output_layer(... weight= + # output_weight ...)` call below is LEFT UNTOUCHED so GRPO + # gradient on lm_head / embedding is preserved exactly. + # ---------------------------------------------------------------- + _mtp_output_weight_v53 = ( + output_weight.detach() + if output_weight is not None + else None + ) + if mtp_in_postprocess: hidden_states = self_model.mtp( input_ids=input_ids, @@ -2237,7 +2261,10 @@ def _patched_postprocess( hidden_states.requires_grad) mtp_logits, _ = self_model.output_layer( _mtp_hs, - weight=output_weight, + # [MTPSharedWeightIsolate-v53] detached weight + # prevents MTP CE grad from contaminating + # shared embedding / lm_head parameter. + weight=_mtp_output_weight_v53, runtime_gather_output=runtime_gather_output, ) # Diagnostic: verify gradient chain is intact From ce448e7dd81d72cfee31c7defe999ff7c7ce44eb Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 5 May 2026 23:09:27 +0800 Subject: [PATCH 125/140] feat(megatron_engine): SpecDecFlow --- areal/engine/megatron_engine.py | 506 ++++++++++++++++++++++++++++++++ 1 file changed, 506 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index b237d5faf1..ad0afe2f41 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -211,6 +211,7 @@ def __init__(self, config: TrainEngineConfig): "v51:MTPGradClipNorm(diag-only; AREAL_MTP_V51_GRAD_CLIP_NORM, default=0 after v52)", "v52:MTPSourceLossCap(default-on; AREAL_MTP_V52_LOSS_CAP_RATIO=, default=2.0)", "v53:MTPSharedWeightIsolate(detach output_weight for MTP output_layer)", + "v54:MTPFreezeGate+DraftEMA+SpecDecFlowLog(AREAL_MTP_V54_FREEZE[default=0],AREAL_MTP_V54_DRAFT_EMA[default=0.0],AREAL_MTP_V54_SPEC_FLOW_LOG[default=1])", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -1560,8 +1561,307 @@ def _hook(_p): self.logger.warning( '[MTPGradClipNorm-v51] failed: %s', _e_v51c, ) + # [SpecDecFlow-v54] PRE_STEP stage — per-MTP-param grad + # diagnostics BEFORE optimizer step. Captures what the + # optimizer is about to apply. Default ON. + try: + import os as _os_v54p + _v54_flow_on = ( + _os_v54p.environ.get( + 'AREAL_MTP_V54_SPEC_FLOW_LOG', '1', + ) == '1' + and getattr(self, 'enable_mtp_training', False) + and getattr(self, 'model', None) is not None + ) + if _v54_flow_on: + _v54_pre_seen = 0 + _v54_pre_with_grad = 0 + _v54_pre_mp_avail = 0 + _v54_pre_mg_avail = 0 + _v54_pre_grad_norm_sq = 0.0 + if not hasattr(self, '_v54_pre_snap'): + self._v54_pre_snap = {} + for _mod_v54p in self.model: + for _n_v54p, _p_v54p in ( + _mod_v54p.named_parameters() + ): + if ('.mtp.' not in _n_v54p + and '.mtp_layers.' not in _n_v54p): + continue + _v54_pre_seen += 1 + _g_v54p = getattr(_p_v54p, 'grad', None) + _g_norm_v54p = -1.0 + _g_amax_v54p = -1.0 + if _g_v54p is not None: + try: + _g_norm_v54p = float( + _g_v54p.detach().float() + .norm().item() + ) + _g_amax_v54p = float( + _g_v54p.detach().abs() + .max().item() + ) + _v54_pre_grad_norm_sq += ( + _g_norm_v54p * _g_norm_v54p + ) + _v54_pre_with_grad += 1 + except Exception: + pass + _mp_v54p = getattr( + _p_v54p, 'main_param', None, + ) + _mg_amax_v54p = -1.0 + if _mp_v54p is not None: + _v54_pre_mp_avail += 1 + _mg_v54p = getattr( + _mp_v54p, 'main_grad', None, + ) + if _mg_v54p is None: + _mg_v54p = getattr( + _mp_v54p, 'grad', None, + ) + if _mg_v54p is not None: + _v54_pre_mg_avail += 1 + try: + _mg_amax_v54p = float( + _mg_v54p.detach().abs() + .max().item() + ) + except Exception: + pass + try: + self._v54_pre_snap[_n_v54p] = ( + _mp_v54p.detach().float().clone() + ) + except Exception: + pass + self.logger.info( + '[SpecDecFlow-v54] stage=pre_step ' + 'name=%s shape=%s dtype=%s ' + 'grad_norm=%.6e grad_amax=%.6e ' + 'main_param_present=%s ' + 'main_grad_amax=%.6e', + _n_v54p, + str(tuple(_p_v54p.shape)), + str(_p_v54p.dtype), + _g_norm_v54p, _g_amax_v54p, + str(_mp_v54p is not None), + _mg_amax_v54p, + ) + _v54_pre_grad_norm = ( + _v54_pre_grad_norm_sq ** 0.5 + ) + self.logger.info( + '[SpecDecFlow-v54] stage=pre_step_summary ' + 'n_mtp_params=%d n_with_grad=%d ' + 'n_main_param=%d n_main_grad=%d ' + 'mtp_grad_norm=%.6e', + _v54_pre_seen, _v54_pre_with_grad, + _v54_pre_mp_avail, _v54_pre_mg_avail, + _v54_pre_grad_norm, + ) + except Exception as _e_v54p: + try: + self.logger.warning( + '[SpecDecFlow-v54] pre_step failed: %r', _e_v54p, + ) + except Exception: + pass + # [MTPFreezeGate-v54] Disambiguation/mitigation control. + # When AREAL_MTP_V54_FREEZE=1 (default '0'=off), zero every + # MTP parameter's .grad AND its main_param.grad/main_grad + # right before the Megatron distributed-optimizer step. + # This cleanly freezes every MTP tensor (enorm/hnorm/ + # eh_proj/transformer_layer/final_layernorm/shared_head), + # leaving the rest of the model to be trained normally by + # GRPO. Any subsequent shipment to sglang will contain + # bit-identical MTP weights. + # + # Usage: set AREAL_MTP_V54_FREEZE=1 for one run. If + # rollout/spec_accept_rate stops declining, MTP weight + # drift (H1) is the cause and EMA should be tuned. + # If the rate still declines, main-model GRPO drift of + # the hidden-state distribution (H2) is the cause and a + # different mitigation is needed. + try: + import os as _os_v54f + _v54_freeze = ( + _os_v54f.environ.get('AREAL_MTP_V54_FREEZE', '0') + == '1' + and getattr(self, 'enable_mtp_training', False) + and getattr(self, 'model', None) is not None + ) + self._v54_freeze_engaged = bool(_v54_freeze) + if _v54_freeze: + _v54_n_zeroed = 0 + for _mod_v54f in self.model: + for _n_v54f, _p_v54f in ( + _mod_v54f.named_parameters() + ): + if ('.mtp.' not in _n_v54f + and '.mtp_layers.' not in _n_v54f): + continue + try: + if _p_v54f.grad is not None: + _p_v54f.grad.detach().zero_() + except Exception: + pass + _mp_v54f = getattr( + _p_v54f, 'main_param', None, + ) + if _mp_v54f is not None: + try: + _mg_v54f = getattr( + _mp_v54f, 'grad', None, + ) + if _mg_v54f is not None: + _mg_v54f.detach().zero_() + except Exception: + pass + _mgf_v54f = getattr( + _mp_v54f, 'main_grad', None, + ) + if _mgf_v54f is not None: + try: + _mgf_v54f.detach().zero_() + except Exception: + pass + _v54_n_zeroed += 1 + self.logger.info( + '[SpecDecFlow-v54] stage=freeze ' + 'name=%s zeroed=True', _n_v54f, + ) + self.logger.info( + '[SpecDecFlow-v54] stage=freeze_summary ' + 'n_zeroed=%d', _v54_n_zeroed, + ) + self.logger.info( + '[MTPFreezeGate-v54] zeroed grads for %d MTP ' + 'params before optimizer.step()', _v54_n_zeroed, + ) + except Exception as _e_v54f: + try: + self.logger.warning( + '[MTPFreezeGate-v54] gate failed: %r', _e_v54f, + ) + except Exception: + pass with trace_scope("megatron_engine.step"): update_successful, grad_norm, _ = self.optimizer.step() + # [SpecDecFlow-v54] POST_STEP stage — per-MTP-param delta + # diagnostics AFTER optimizer step. Captures what the + # optimizer actually applied (fp32 master delta). + try: + import os as _os_v54q + _v54_flow_on2 = ( + _os_v54q.environ.get( + 'AREAL_MTP_V54_SPEC_FLOW_LOG', '1', + ) == '1' + and bool(update_successful) + and getattr(self, 'enable_mtp_training', False) + and getattr(self, 'model', None) is not None + ) + if _v54_flow_on2: + _v54_post_seen = 0 + _v54_post_stalled = 0 + _v54_post_max_delta = 0.0 + _v54_post_max_name = '' + for _mod_v54q in self.model: + for _n_v54q, _p_v54q in ( + _mod_v54q.named_parameters() + ): + if ('.mtp.' not in _n_v54q + and '.mtp_layers.' not in _n_v54q): + continue + _v54_post_seen += 1 + _mp_v54q = getattr( + _p_v54q, 'main_param', None, + ) + _delta_amax = -1.0 + _delta_l2 = -1.0 + _post_amax = -1.0 + _bf16_cast_diff = -1.0 + if _mp_v54q is not None: + try: + _pre_v54q = getattr( + self, '_v54_pre_snap', {}, + ).get(_n_v54q) + _cur_fp32 = ( + _mp_v54q.detach().float() + ) + _post_amax = float( + _cur_fp32.abs().max().item() + ) + if ( + _pre_v54q is not None + and _pre_v54q.shape + == _cur_fp32.shape + ): + _d = _cur_fp32 - _pre_v54q + _delta_amax = float( + _d.abs().max().item() + ) + _delta_l2 = float( + _d.norm().item() + ) + if _delta_amax == 0.0: + _v54_post_stalled += 1 + if _delta_amax > ( + _v54_post_max_delta + ): + _v54_post_max_delta = ( + _delta_amax + ) + _v54_post_max_name = ( + _n_v54q + ) + try: + _bf = _p_v54q.data.detach() + _bf16_cast_diff = float( + ( + _cur_fp32 + - _bf.float() + ).abs().max().item() + ) + except Exception: + pass + except Exception: + pass + self.logger.info( + '[SpecDecFlow-v54] stage=post_step ' + 'name=%s post_amax=%.6e ' + 'delta_amax=%.6e delta_l2=%.6e ' + 'bf16_cast_diff=%.6e', + _n_v54q, _post_amax, + _delta_amax, _delta_l2, + _bf16_cast_diff, + ) + self.logger.info( + '[SpecDecFlow-v54] stage=post_step_summary ' + 'n_mtp_params=%d n_stalled=%d ' + 'max_delta=%.6e max_delta_name=%s ' + 'freeze_engaged=%s', + _v54_post_seen, _v54_post_stalled, + _v54_post_max_delta, + _v54_post_max_name, + str(getattr( + self, '_v54_freeze_engaged', False, + )), + ) + # release pre-snapshot + try: + self._v54_pre_snap = {} + except Exception: + pass + except Exception as _e_v54q: + try: + self.logger.warning( + '[SpecDecFlow-v54] post_step failed: %r', + _e_v54q, + ) + except Exception: + pass # [MTPMasterAmp-v47] post-step delta amplification. # For each MTP fp32 master tensor whose Adam step delta # amax is smaller than beta * bf16_ULP, rescale the @@ -3660,6 +3960,212 @@ def _update_bucket_weights_from_distributed( self.logger.warning( "[MTPSendPreBcast-v25] probe error: %s", _e_v25s, ) + # [MTPDraftEMA-v54] Optional EMA smoothing of the bf16 wire + # payload shipped to sglang, applied right before the RPC + # update_weights_from_distributed() call. alpha in (0,1) + # produces: + # W_ship[t] = (1-alpha) * W_ship[t-1] + alpha * W_train[t] + # dampening per-step MTP update noise as seen by the EAGLE + # draft head. alpha==0.0 (default) or alpha==1.0 is + # pass-through (feature disabled / no smoothing). + _v54_ema_applied_names = set() + try: + import os as _os_v54e + _v54_alpha_raw = _os_v54e.environ.get( + 'AREAL_MTP_V54_DRAFT_EMA', '0.0', + ) + try: + _v54_alpha = float(_v54_alpha_raw) + except Exception: + _v54_alpha = 0.0 + _v54_ema_on = (0.0 < _v54_alpha < 1.0) + self._v54_ema_alpha = _v54_alpha + self._v54_ema_enabled = _v54_ema_on + if _v54_ema_on: + if not hasattr(self, '_v54_ema_state'): + self._v54_ema_state = {} + _v54_ema_n = 0 + for _v54_idx, _v54_np in enumerate( + converted_named_tensors + ): + _v54_name, _v54_param = _v54_np + if not ( + '.enorm' in _v54_name + or '.hnorm' in _v54_name + or '.eh_proj' in _v54_name + or '.shared_head.' in _v54_name + or '.mtp_layers.' in _v54_name + ): + continue + try: + _v54_cur = _v54_param.data + _v54_prev = self._v54_ema_state.get(_v54_name) + _v54_pre_norm = float( + _v54_cur.float().norm().item() + ) + if ( + _v54_prev is not None + and _v54_prev.shape == _v54_cur.shape + ): + _v54_smoothed = ( + (1.0 - _v54_alpha) * _v54_prev.to( + torch.float32 + ) + + _v54_alpha * _v54_cur.to( + torch.float32 + ) + ) + _v54_smoothed = _v54_smoothed.to( + _v54_cur.dtype + ).contiguous() + _v54_param.data.copy_(_v54_smoothed) + self._v54_ema_state[_v54_name] = ( + _v54_smoothed.detach().clone() + ) + _v54_ema_applied_names.add(_v54_name) + _v54_ema_n += 1 + _v54_post_norm = float( + _v54_param.data.float() + .norm().item() + ) + self.logger.info( + '[SpecDecFlow-v54] stage=ema ' + 'name=%s alpha=%.4f ' + 'pre_norm=%.6e post_norm=%.6e ' + 'applied=True', + _v54_name, _v54_alpha, + _v54_pre_norm, _v54_post_norm, + ) + else: + self._v54_ema_state[_v54_name] = ( + _v54_cur.detach().clone() + ) + self.logger.info( + '[SpecDecFlow-v54] stage=ema ' + 'name=%s alpha=%.4f ' + 'pre_norm=%.6e post_norm=%.6e ' + 'applied=False reason=seed', + _v54_name, _v54_alpha, + _v54_pre_norm, _v54_pre_norm, + ) + except Exception: + continue + self.logger.info( + '[MTPDraftEMA-v54] applied alpha=%.4f to %d ' + 'MTP wire tensors (cache_size=%d)', + _v54_alpha, _v54_ema_n, + len(self._v54_ema_state), + ) + self.logger.info( + '[SpecDecFlow-v54] stage=ema_summary ' + 'alpha=%.4f n_applied=%d cache_size=%d', + _v54_alpha, _v54_ema_n, + len(self._v54_ema_state), + ) + except Exception as _e_v54e: + try: + self.logger.warning( + '[MTPDraftEMA-v54] gate failed: %r', _e_v54e, + ) + except Exception: + pass + # [SpecDecFlow-v54] SHIP stage — per-MTP-wire-tensor payload + # diagnostics right before dist.broadcast(). Answers: + # 'exactly which bytes are shipped to sglang this version?'. + try: + import os as _os_v54s + _v54_flow_on3 = ( + _os_v54s.environ.get( + 'AREAL_MTP_V54_SPEC_FLOW_LOG', '1', + ) == '1' + ) + if _v54_flow_on3: + _v54_ship_n = 0 + _v54_ship_bytes = 0 + _v54_ship_sq = 0.0 + _v54_ship_cnt = 0 + _v54_ship_first = None + _v54_ship_first_l2 = -1.0 + _v54_ship_mtp_only = 0 + for _v54_si, (_v54_sn, _v54_st) in enumerate( + converted_named_tensors + ): + _is_mtp = ( + '.enorm' in _v54_sn + or '.hnorm' in _v54_sn + or '.eh_proj' in _v54_sn + or '.shared_head.' in _v54_sn + or '.mtp_layers.' in _v54_sn + ) + if not _is_mtp: + continue + _v54_ship_mtp_only += 1 + try: + _td = _v54_st.detach() + _tf = _td.float() + _l2 = float(_tf.norm().item()) + _am = float(_tf.abs().mean().item()) + _ax = float(_tf.abs().max().item()) + _v54_ship_sq += _l2 * _l2 + _v54_ship_cnt += int(_tf.numel()) + _v54_ship_bytes += int( + _td.numel() * _td.element_size() + ) + _v54_ship_n += 1 + if _v54_ship_first is None: + _v54_ship_first = _v54_sn + _v54_ship_first_l2 = _l2 + self.logger.info( + '[SpecDecFlow-v54] stage=ship ' + 'idx=%d name=%s dtype=%s shape=%s ' + 'l2=%.6e abs_mean=%.6e abs_max=%.6e ' + 'ema_applied=%s', + _v54_si, _v54_sn, str(_td.dtype), + str(tuple(_td.shape)), + _l2, _am, _ax, + str(_v54_sn in _v54_ema_applied_names), + ) + except Exception: + continue + _v54_wire_norm = _v54_ship_sq ** 0.5 + _v54_prev_wire = getattr( + self, '_v54_prev_wire_norm', None, + ) + self._v54_prev_wire_norm = _v54_wire_norm + _v54_d_wire = -1.0 + if _v54_prev_wire is not None: + _v54_d_wire = abs( + _v54_wire_norm - _v54_prev_wire + ) + self.logger.info( + '[SpecDecFlow-v54] stage=ship_summary ' + 'version=%s n_mtp_shipped=%d ' + 'total_bytes=%d wire_norm=%.6e ' + 'd_wire_norm=%.6e first=%s first_l2=%.6e ' + 'ema_enabled=%s ema_alpha=%.4f ' + 'freeze_engaged=%s', + str(getattr(meta, 'version', 'NA')), + _v54_ship_n, _v54_ship_bytes, + _v54_wire_norm, _v54_d_wire, + str(_v54_ship_first), + _v54_ship_first_l2, + str(getattr( + self, '_v54_ema_enabled', False, + )), + float(getattr( + self, '_v54_ema_alpha', 0.0, + )), + str(getattr( + self, '_v54_freeze_engaged', False, + )), + ) + except Exception as _e_v54s: + try: + self.logger.warning( + '[SpecDecFlow-v54] ship failed: %r', _e_v54s, + ) + except Exception: + pass _t_post0 = _diag_time.time() fut = self.rollout_engine.update_weights_from_distributed(meta, param_specs) self.logger.info( From da77672c0765b010d75304e26dbddc72f1a45077 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 5 May 2026 23:59:28 +0800 Subject: [PATCH 126/140] feat: learn rate --- areal/engine/megatron_engine.py | 161 ++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ad0afe2f41..4864848b71 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -212,6 +212,7 @@ def __init__(self, config: TrainEngineConfig): "v52:MTPSourceLossCap(default-on; AREAL_MTP_V52_LOSS_CAP_RATIO=, default=2.0)", "v53:MTPSharedWeightIsolate(detach output_weight for MTP output_layer)", "v54:MTPFreezeGate+DraftEMA+SpecDecFlowLog(AREAL_MTP_V54_FREEZE[default=0],AREAL_MTP_V54_DRAFT_EMA[default=0.0],AREAL_MTP_V54_SPEC_FLOW_LOG[default=1])", + "v55:MTPLRBoost(AREAL_MTP_V55_MTP_LR_BOOST[default=1.0])", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -1747,6 +1748,166 @@ def _hook(_p): ) except Exception: pass + # [MTPLRBoost-v55] Boost MTP gradient learning rate by a + # configurable multiplier just before optimizer.step(). + # Evidence-driven minimal fix: log.42 (Run A, v54 freeze=1) + # vs log.41 (Run B, v53) confirmed H2 — decline is + # dominated by main-model hidden-state drift, not MTP + # weight drift. In slime / verl-style EAGLE RL training + # the draft (MTP) head needs to track main-model drift + # faster than vanilla co-training allows; the standard + # pattern is an MTP-specific LR multiplier so the draft + # head learns faster than the target. + # Default 1.0 = exact baseline (full no-op). Skip + # entirely when v54 freeze is engaged (cannot scale + # zeroed grads meaningfully). + try: + import os as _os_v55b + _v55_mult_raw = _os_v55b.environ.get( + 'AREAL_MTP_V55_MTP_LR_BOOST', '1.0', + ) + try: + _v55_mult = float(_v55_mult_raw) + except Exception: + _v55_mult = 1.0 + _v55_freeze_engaged = bool( + getattr(self, '_v54_freeze_engaged', False) + ) + _v55_active = ( + _v55_mult > 1.0 + and not _v55_freeze_engaged + and getattr(self, 'enable_mtp_training', False) + and getattr(self, 'model', None) is not None + ) + self._v55_lr_boost_active = bool(_v55_active) + self._v55_lr_boost_mult = float(_v55_mult) + if _v55_mult > 1.0 and _v55_freeze_engaged: + try: + self.logger.info( + '[MTPLRBoost-v55] ' + 'skipped reason=freeze_engaged' + ) + except Exception: + pass + elif _v55_active: + _v55_n_scaled = 0 + _v55_pre_sq = 0.0 + _v55_post_sq = 0.0 + for _mod_v55 in self.model: + for _n_v55, _p_v55 in ( + _mod_v55.named_parameters() + ): + if ('.mtp.' not in _n_v55 + and '.mtp_layers.' not in _n_v55): + continue + _v55_scaled_any = False + try: + _g_v55 = getattr(_p_v55, 'grad', None) + if _g_v55 is not None: + _gn = float( + _g_v55.detach().float() + .norm().item() + ) + _v55_pre_sq += _gn * _gn + _g_v55.detach().mul_(_v55_mult) + _gn2 = float( + _g_v55.detach().float() + .norm().item() + ) + _v55_post_sq += _gn2 * _gn2 + _v55_scaled_any = True + except Exception: + pass + _mp_v55 = getattr( + _p_v55, 'main_param', None, + ) + if _mp_v55 is not None: + try: + _mg_v55 = getattr( + _mp_v55, 'grad', None, + ) + if _mg_v55 is not None: + if not _v55_scaled_any: + _gn = float( + _mg_v55.detach() + .float().norm() + .item() + ) + _v55_pre_sq += _gn * _gn + _mg_v55.detach().mul_( + _v55_mult + ) + _gn2 = float( + _mg_v55.detach() + .float().norm() + .item() + ) + _v55_post_sq += ( + _gn2 * _gn2 + ) + _v55_scaled_any = True + else: + _mg_v55.detach().mul_( + _v55_mult + ) + except Exception: + pass + try: + _mgf_v55 = getattr( + _mp_v55, 'main_grad', None, + ) + if _mgf_v55 is not None: + if not _v55_scaled_any: + _gn = float( + _mgf_v55.detach() + .float().norm() + .item() + ) + _v55_pre_sq += _gn * _gn + _mgf_v55.detach().mul_( + _v55_mult + ) + _gn2 = float( + _mgf_v55.detach() + .float().norm() + .item() + ) + _v55_post_sq += ( + _gn2 * _gn2 + ) + _v55_scaled_any = True + else: + _mgf_v55.detach().mul_( + _v55_mult + ) + except Exception: + pass + if _v55_scaled_any: + _v55_n_scaled += 1 + _v55_pre_norm = _v55_pre_sq ** 0.5 + _v55_post_norm = _v55_post_sq ** 0.5 + try: + self.logger.info( + '[MTPLRBoost-v55] mult=%.4f ' + 'n_scaled=%d mtp_grad_norm_pre=%.6e ' + 'mtp_grad_norm_post=%.6e', + _v55_mult, _v55_n_scaled, + _v55_pre_norm, _v55_post_norm, + ) + self.logger.info( + '[SpecDecFlow-v54] stage=lr_boost ' + 'mult=%.4f n_scaled=%d', + _v55_mult, _v55_n_scaled, + ) + except Exception: + pass + except Exception as _e_v55b: + try: + self.logger.warning( + '[MTPLRBoost-v55] boost failed: %r', _e_v55b, + ) + except Exception: + pass with trace_scope("megatron_engine.step"): update_successful, grad_norm, _ = self.optimizer.step() # [SpecDecFlow-v54] POST_STEP stage — per-MTP-param delta From b1c1c195bb611a99695f3c80737962abee31965b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 00:43:04 +0800 Subject: [PATCH 127/140] feat(megatron_engine): add --- areal/engine/megatron_engine.py | 329 +++++++++++++++++++++++++++++++- 1 file changed, 328 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 4864848b71..0fefb825b2 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -213,6 +213,7 @@ def __init__(self, config: TrainEngineConfig): "v53:MTPSharedWeightIsolate(detach output_weight for MTP output_layer)", "v54:MTPFreezeGate+DraftEMA+SpecDecFlowLog(AREAL_MTP_V54_FREEZE[default=0],AREAL_MTP_V54_DRAFT_EMA[default=0.0],AREAL_MTP_V54_SPEC_FLOW_LOG[default=1])", "v55:MTPLRBoost(AREAL_MTP_V55_MTP_LR_BOOST[default=1.0])", + "v56:MTPShipSummaryFix+GradTrace+LossTrace(AREAL_MTP_V56_GRAD_TRACE[default=1],AREAL_MTP_V56_LOSS_TRACE[default=1])", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -1669,6 +1670,205 @@ def _hook(_p): ) except Exception: pass + # [MTPGradTrace-v56] Detailed per-MTP-param grad trace. + # Captures `.grad`, `.main_param.grad`, and `.main_param.main_grad` + # exactly as they arrive from backward, BEFORE the v54 freeze + # block (which would zero them) and BEFORE the v55 LR boost + # block (which would scale them). Default ON: gated by + # AREAL_MTP_V56_GRAD_TRACE (default='1'). + try: + import os as _os_v56g + _v56_grad_on = ( + _os_v56g.environ.get( + 'AREAL_MTP_V56_GRAD_TRACE', '1', + ) == '1' + and getattr(self, 'enable_mtp_training', False) + and getattr(self, 'model', None) is not None + ) + if _v56_grad_on: + try: + import torch as _torch_v56g + import torch.distributed as _dist_v56g + _v56g_rank = ( + _dist_v56g.get_rank() + if _dist_v56g.is_initialized() else 0 + ) + except Exception: + _torch_v56g = None + _v56g_rank = 0 + _v56g_emb_ptrs = set() + try: + for _mod_e in self.model: + for _ne, _pe in _mod_e.named_parameters(): + if ( + 'embedding' in _ne + or 'word_embeddings' in _ne + ): + try: + _v56g_emb_ptrs.add(int(_pe.data_ptr())) + except Exception: + pass + except Exception: + pass + _v56g_n = 0 + _v56g_n_with_grad = 0 + _v56g_n_with_main_grad = 0 + _v56g_n_shared = 0 + _v56g_any_nan = False + _v56g_any_inf = False + for _mod_v56g in self.model: + for _n_v56g, _p_v56g in ( + _mod_v56g.named_parameters() + ): + if not ( + '.mtp.' in _n_v56g + or '.mtp_layers.' in _n_v56g + or '.enorm' in _n_v56g + or '.hnorm' in _n_v56g + or '.eh_proj' in _n_v56g + or '.shared_head.' in _n_v56g + ): + continue + _v56g_n += 1 + _g = getattr(_p_v56g, 'grad', None) + _g_present = _g is not None + _g_dtype = str(getattr(_g, 'dtype', None)) + _g_numel = ( + int(_g.numel()) if _g_present else 0 + ) + _g_norm = -1.0 + _g_amax = -1.0 + _g_isfinite = True + if _g_present: + try: + _gd = _g.detach().float() + _g_norm = float(_gd.norm().item()) + _g_amax = float( + _gd.abs().max().item() + ) + _g_isfinite = bool( + _torch_v56g.isfinite(_gd).all().item() + ) if _torch_v56g is not None else True + if _torch_v56g is not None: + if bool( + _torch_v56g.isnan(_gd).any().item() + ): + _v56g_any_nan = True + if bool( + _torch_v56g.isinf(_gd).any().item() + ): + _v56g_any_inf = True + _v56g_n_with_grad += 1 + except Exception: + pass + _mp = getattr(_p_v56g, 'main_param', None) + _mp_present = _mp is not None + _mp_dtype = str(getattr(_mp, 'dtype', None)) + _mp_grad = ( + getattr(_mp, 'grad', None) + if _mp_present else None + ) + _mp_grad_present = _mp_grad is not None + _mp_grad_norm = -1.0 + if _mp_grad_present: + try: + _mp_grad_norm = float( + _mp_grad.detach().float() + .norm().item() + ) + except Exception: + pass + _main_grad = ( + getattr(_mp, 'main_grad', None) + if _mp_present else None + ) + _mg_present = _main_grad is not None + _mg_dtype = str(getattr(_main_grad, 'dtype', None)) + _mg_norm = -1.0 + _mg_amax = -1.0 + _mg_isfinite = True + if _mg_present: + try: + _mgd = _main_grad.detach().float() + _mg_norm = float(_mgd.norm().item()) + _mg_amax = float( + _mgd.abs().max().item() + ) + _mg_isfinite = bool( + _torch_v56g.isfinite(_mgd) + .all().item() + ) if _torch_v56g is not None else True + if _torch_v56g is not None: + if bool( + _torch_v56g.isnan(_mgd) + .any().item() + ): + _v56g_any_nan = True + if bool( + _torch_v56g.isinf(_mgd) + .any().item() + ): + _v56g_any_inf = True + _v56g_n_with_main_grad += 1 + except Exception: + pass + _shared = False + try: + _shared = ( + int(_p_v56g.data_ptr()) + in _v56g_emb_ptrs + ) + except Exception: + pass + if _shared: + _v56g_n_shared += 1 + _gf = getattr(_p_v56g, 'grad_fn', None) + self.logger.info( + '[MTPGradTrace-v56] rank=%d name=%s ' + 'grad_present=%s grad_dtype=%s ' + 'grad_numel=%d grad_norm=%.6e ' + 'grad_amax=%.6e grad_isfinite=%s ' + 'main_param_present=%s ' + 'main_param_dtype=%s ' + 'main_param_grad_present=%s ' + 'main_param_grad_norm=%.6e ' + 'main_grad_present=%s ' + 'main_grad_dtype=%s ' + 'main_grad_norm=%.6e ' + 'main_grad_amax=%.6e ' + 'main_grad_isfinite=%s ' + 'grad_fn_present=%s requires_grad=%s ' + 'is_leaf=%s shared_tensor=%s', + _v56g_rank, _n_v56g, + str(_g_present), _g_dtype, + _g_numel, _g_norm, + _g_amax, str(_g_isfinite), + str(_mp_present), _mp_dtype, + str(_mp_grad_present), _mp_grad_norm, + str(_mg_present), _mg_dtype, + _mg_norm, _mg_amax, str(_mg_isfinite), + str(_gf is not None), + str(bool(_p_v56g.requires_grad)), + str(bool(_p_v56g.is_leaf)), + str(_shared), + ) + if _v56g_rank == 0: + self.logger.info( + '[MTPGradTrace-v56] summary n_mtp=%d ' + 'n_with_grad=%d n_with_main_grad=%d ' + 'n_shared_tensor=%d any_nan=%s any_inf=%s', + _v56g_n, _v56g_n_with_grad, + _v56g_n_with_main_grad, _v56g_n_shared, + str(_v56g_any_nan), str(_v56g_any_inf), + ) + except Exception as _e_v56g: + try: + self.logger.warning( + '[MTPGradTrace-v56] grad trace failed: %r', + _e_v56g, + ) + except Exception: + pass # [MTPFreezeGate-v54] Disambiguation/mitigation control. # When AREAL_MTP_V54_FREEZE=1 (default '0'=off), zero every # MTP parameter's .grad AND its main_param.grad/main_grad @@ -1908,6 +2108,108 @@ def _hook(_p): ) except Exception: pass + # [MTPLossTrace-v56] Best-effort defensive trace of any MTP + # loss state stored on `self`, run right before optimizer.step(). + # Gated by AREAL_MTP_V56_LOSS_TRACE (default='1'). + try: + import os as _os_v56l + _v56_loss_on = ( + _os_v56l.environ.get( + 'AREAL_MTP_V56_LOSS_TRACE', '1', + ) == '1' + ) + if _v56_loss_on: + try: + import torch as _torch_v56l + except Exception: + _torch_v56l = None + _v56l_keys = [] + _v56l_found = [] + try: + _v56l_attrs = [ + _a for _a in dir(self) + if ( + ('mtp' in _a.lower() + and 'loss' in _a.lower()) + or _a in ( + 'total_loss', '_last_mtp_loss', + 'mtp_loss', + '_mtp_loss_for_backward', + '_mtp_loss_value', + ) + ) + ] + except Exception: + _v56l_attrs = [] + for _a in _v56l_attrs: + try: + _v = getattr(self, _a, None) + except Exception: + continue + if _v is None: + continue + _v56l_keys.append(_a) + _is_tensor = ( + _torch_v56l is not None + and isinstance(_v, _torch_v56l.Tensor) + ) + if _is_tensor: + try: + _val = ( + float(_v.detach().float().mean().item()) + if _v.numel() > 0 else float('nan') + ) + except Exception: + _val = float('nan') + _v56l_found.append(_a) + try: + self.logger.info( + '[MTPLossTrace-v56] attr=%s ' + 'kind=tensor value=%.6e dtype=%s ' + 'numel=%d requires_grad=%s ' + 'grad_fn_present=%s', + _a, _val, str(_v.dtype), + int(_v.numel()), + str(bool(_v.requires_grad)), + str( + getattr(_v, 'grad_fn', None) + is not None + ), + ) + except Exception: + pass + elif isinstance(_v, (int, float)): + _v56l_found.append(_a) + try: + self.logger.info( + '[MTPLossTrace-v56] attr=%s ' + 'kind=scalar value=%s', + _a, str(_v), + ) + except Exception: + pass + elif isinstance(_v, (list, tuple)): + try: + self.logger.info( + '[MTPLossTrace-v56] attr=%s ' + 'kind=%s len=%d', + _a, type(_v).__name__, len(_v), + ) + except Exception: + pass + self.logger.info( + '[MTPLossTrace-v56] found=%s keys=%s', + str(bool(_v56l_found)), + str(_v56l_keys), + ) + except Exception as _e_v56l: + try: + self.logger.warning( + '[MTPLossTrace-v56] loss trace failed: %r', + _e_v56l, + ) + except Exception: + pass with trace_scope("megatron_engine.step"): update_successful, grad_norm, _ = self.optimizer.step() # [SpecDecFlow-v54] POST_STEP stage — per-MTP-param delta @@ -4248,8 +4550,18 @@ def _update_bucket_weights_from_distributed( _v54_ship_first = None _v54_ship_first_l2 = -1.0 _v54_ship_mtp_only = 0 + # [MTPShipSummaryFix-v56] Iterate the REAL MTP wire + # payload (`mtp_hf_tensors`, stashed on self at the + # `_update_weights_from_distributed` call site) instead + # of `converted_named_tensors` (which is the main-model + # bucket payload during the MTP wire path). This fixes + # the v54 ship_summary log that always reported + # n_mtp_shipped=0. + _v56_ship_iter = list( + getattr(self, '_v56_mtp_hf_tensors', []) or [] + ) for _v54_si, (_v54_sn, _v54_st) in enumerate( - converted_named_tensors + _v56_ship_iter ): _is_mtp = ( '.enorm' in _v54_sn @@ -4257,6 +4569,10 @@ def _update_bucket_weights_from_distributed( or '.eh_proj' in _v54_sn or '.shared_head.' in _v54_sn or '.mtp_layers.' in _v54_sn + # [MTPShipSummaryFix-v56] Items in mtp_hf_tensors + # are already MTP-only, so accept anything that + # came from that list as MTP wire payload. + or True ) if not _is_mtp: continue @@ -4300,6 +4616,7 @@ def _update_bucket_weights_from_distributed( ) self.logger.info( '[SpecDecFlow-v54] stage=ship_summary ' + '[MTPShipSummaryFix-v56] ' 'version=%s n_mtp_shipped=%d ' 'total_bytes=%d wire_norm=%.6e ' 'd_wire_norm=%.6e first=%s first_l2=%.6e ' @@ -6729,6 +7046,16 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"{len(converted_named_tensors)} tensors at elapsed=" f"{_diag_time.time() - _diag_t0:.3f}s" ) + # [MTPShipSummaryFix-v56] Stash the actual MTP wire payload + # (`mtp_hf_tensors`) on self so the v54 ship-stage diagnostic + # block inside `_update_bucket_weights_from_distributed` can + # iterate the *correct* list (the one truly broadcast to the + # inference engine), not `converted_named_tensors` which holds + # main-model bucket payload during the MTP wire path. + try: + self._v56_mtp_hf_tensors = list(mtp_hf_tensors) + except Exception: + self._v56_mtp_hf_tensors = [] self._update_bucket_weights_from_distributed(meta, converted_named_tensors) self.logger.info( f"[DiagUW] _update_bucket_weights_from_distributed completed at elapsed=" From 5c648e2d531dd82b36bba75e28d796f7bd3cd3cd Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 10:23:23 +0800 Subject: [PATCH 128/140] feat(engine): MTPStochasticRoundBf16 --- areal/engine/megatron_engine.py | 163 +++++++++++++++++++++++++++++++- 1 file changed, 158 insertions(+), 5 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 0fefb825b2..79c45b57c1 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -214,6 +214,7 @@ def __init__(self, config: TrainEngineConfig): "v54:MTPFreezeGate+DraftEMA+SpecDecFlowLog(AREAL_MTP_V54_FREEZE[default=0],AREAL_MTP_V54_DRAFT_EMA[default=0.0],AREAL_MTP_V54_SPEC_FLOW_LOG[default=1])", "v55:MTPLRBoost(AREAL_MTP_V55_MTP_LR_BOOST[default=1.0])", "v56:MTPShipSummaryFix+GradTrace+LossTrace(AREAL_MTP_V56_GRAD_TRACE[default=1],AREAL_MTP_V56_LOSS_TRACE[default=1])", + "v57:MTPStochasticRoundBf16(AREAL_MTP_V57_STOCHASTIC_ROUND[default=1],AREAL_MTP_V57_SR_MIN_DRIFT_RATIO[default=0.0])+ForceTickRatioFire+K2", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", ] @@ -6172,24 +6173,39 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: except Exception: _ft_on_v46 = True if _ft_on_v46: + # [v57] Tighten default K from 8 -> 2. + # Rationale: for high-magnitude LayerNorm + # tensors with sub-ULP drift, the natural + # stale-counter reaches K=8 only after + # training has already drifted far enough + # for main/draft mismatch to dominate + # accept_rate. K=2 bounds IPC staleness + # to a single sync. try: _ft_k_v46 = int( _os_v16.environ.get( 'AREAL_MTP_V46_FORCE_TICK_K', - '8', + '2', ) ) except Exception: - _ft_k_v46 = 8 + _ft_k_v46 = 2 + # [v57] Tighten default ratio 0.10 -> 0.05. + # resid_absmax grows ~drift per sync; at + # ratio=0.05 the ratio trigger fires once + # resid crosses 5%% of ULP, which is the + # smallest safe fraction where SR flip + # probability makes the ship_flips count + # statistically observable. try: _ft_ratio_v46 = float( _os_v16.environ.get( 'AREAL_MTP_V46_FORCE_TICK_RATIO', - '0.10', + '0.05', ) ) except Exception: - _ft_ratio_v46 = 0.10 + _ft_ratio_v46 = 0.05 if not hasattr( self, '_mtp_v46_stale' ): @@ -6233,8 +6249,17 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: >= _ft_ratio_v46 * _ft_ulp ) _ft_fired = False + # [v57] Fix defect: v46 computed + # _ft_trigger_ratio but never used it + # as a fire condition. The ratio trigger + # is the only path that fires for + # sub-ULP-drift LayerNorm tensors inside + # a normal (non-stale) sync cadence. if ( - _ft_trigger_stale + ( + _ft_trigger_stale + or _ft_trigger_ratio + ) and _ft_ulp > 0 ): # Promote _u by sign(resid or @@ -6292,11 +6317,139 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _ft_fired, ) ) + # [MTPStochasticRoundBf16-v57] + # Replace RNE with stochastic rounding for + # MTP payload so sub-ULP drift propagates + # across the tensor with expected flip + # count = numel * drift / ULP. Preserves + # long-run unbiasedness E[SR(x)] = x. + try: + _sr_on_v57 = ( + _os_v16.environ.get( + 'AREAL_MTP_V57_STOCHASTIC_ROUND', + '1', + ) == '1' + ) + except Exception: + _sr_on_v57 = True + _sr_applied_v57 = False + _sr_drift_ratio_v57 = -1.0 + if _sr_on_v57 and _u.dtype == _torch_v16.float32: + try: + _sr_min_ratio_v57 = float( + _os_v16.environ.get( + 'AREAL_MTP_V57_SR_MIN_DRIFT_RATIO', + '0.0', + ) + ) + except Exception: + _sr_min_ratio_v57 = 0.0 + try: + # Element-wise bf16 ULP derived + # from each element's magnitude. + # bf16 ULP(x) = 2^(e_x - 7) where + # 2^e_x <= |x| < 2^(e_x+1). For + # |x|=0 we use the tensor's global + # ulp_max as a fallback (mostly + # irrelevant since 0 rounds to 0). + _u_abs = _u.abs() + _nz_mask = _u_abs > 0 + # log2 is safe only on positives. + _log2u = _torch_v16.where( + _nz_mask, + _torch_v16.log2( + _torch_v16.where( + _nz_mask, + _u_abs, + _torch_v16.ones_like(_u_abs), + ) + ), + _torch_v16.zeros_like(_u_abs), + ) + _e_elem = _torch_v16.floor(_log2u) + _ulp_elem = _torch_v16.pow( + _torch_v16.full_like( + _e_elem, 2.0 + ), + _e_elem - 7.0, + ) + # For zero elements use tensor-level + # ulp (still zero contribution). + _ulp_elem = _torch_v16.where( + _nz_mask, + _ulp_elem, + _torch_v16.full_like( + _ulp_elem, + max(_ft_ulp, 0.0) if _ft_on_v46 else 0.0, + ), + ) + # Drift-gating check (optional). + _sr_enable_this = True + if _sr_min_ratio_v57 > 0: + try: + _drift_abs_max_v57 = 0.0 + if _r_prev is not None and _r_prev.shape == _u.shape: + _drift_abs_max_v57 = float( + _r_prev.abs().max().item() + ) + _ulp_global = float( + _ulp_elem.max().item() + ) + if _ulp_global > 0: + _sr_drift_ratio_v57 = ( + _drift_abs_max_v57 / _ulp_global + ) + else: + _sr_drift_ratio_v57 = 0.0 + # If RNE is already naturally + # flipping, skip SR to keep + # training deterministic. + if _sr_drift_ratio_v57 >= _sr_min_ratio_v57: + _sr_enable_this = False + except Exception: + _sr_enable_this = True + if _sr_enable_this: + # Dither: u ~ Uniform[-0.5, 0.5] + # per-element, scale by ulp_elem + # so that RNE(_u + u*ulp_elem) + # realises the SR rounding. + _dither = ( + _torch_v16.rand_like(_u) - 0.5 + ) * _ulp_elem + _u = _u + _dither + _sr_applied_v57 = True + except Exception as _e_sr_v57: + try: + self.logger.info( + '[MTPStochasticRoundBf16-v57] ' + 'SR failed name=%s err=%r; ' + 'falling back to RNE.', + _nm_sd, _e_sr_v57, + ) + except Exception: + pass + _sr_applied_v57 = False # RNE fp32 -> bf16 and retrieve actual # quantized fp32 value for residual calc. + # When v57 SR applied, the dithered _u + # combined with RNE here is mathematically + # equivalent to per-element stochastic + # rounding of the original fp32 master. _bf16 = _u.to(_torch_v16.bfloat16) _bb = _bf16.float() _new_res = (_u - _bb).detach().clone() + if _sr_applied_v57: + try: + self.logger.info( + '[MTPStochasticRoundBf16-v57] ' + 'name=%s shape=%s numel=%d ' + 'drift_ratio=%.3e applied=True', + _nm_sd, tuple(_u.shape), + int(_u.numel()), + _sr_drift_ratio_v57, + ) + except Exception: + pass # Diagnostic: count elements whose bf16 # representation differs from the plain # RNE(_tn_sd) baseline (i.e. how many were From a123305449bc29699606be8e510df04d355c7a32 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 7 May 2026 21:50:14 +0800 Subject: [PATCH 129/140] feat(megatron_engine): Align --- areal/engine/megatron_engine.py | 293 ++++++++++++++++++++++++++++++-- 1 file changed, 280 insertions(+), 13 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 79c45b57c1..3dec7432af 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -217,6 +217,10 @@ def __init__(self, config: TrainEngineConfig): "v57:MTPStochasticRoundBf16(AREAL_MTP_V57_STOCHASTIC_ROUND[default=1],AREAL_MTP_V57_SR_MIN_DRIFT_RATIO[default=0.0])+ForceTickRatioFire+K2", "v17:MTPNativeAutoScaler+ConsumerBypass" "(AREAL_MTP_NATIVE_AUTOSCALER,autograd_in_graph)", + "v58:MTPSlimeAlign(AREAL_MTP_SLIME_ALIGN[default=1]):" + "disable Path3-detach/v53-weight-detach/v52-cap/" + "FIFO-append/v50-gradFp32/v57-SR; set_loss_scale=" + "loss_scale/num_mb (Megatron-Core native = slime)", ] _banner_flags = { "AREAL_MTP_FP32_BROADCAST": @@ -231,7 +235,35 @@ def __init__(self, config: TrainEngineConfig): "AREAL_MTP_V30_DIAG": _os_banner.environ.get( "AREAL_MTP_V30_DIAG", "1"), + "AREAL_MTP_SLIME_ALIGN": + _os_banner.environ.get( + "AREAL_MTP_SLIME_ALIGN", "1"), } + try: + _slime_align_on = ( + _os_banner.environ.get( + "AREAL_MTP_SLIME_ALIGN", "1") == "1" + ) + self.logger.info( + "[MTPSlimeAlign] AREAL_MTP_SLIME_ALIGN=%s. " + "When ON: A) Path3 detach SKIPPED, " + "B) output_weight NOT detached, " + "C) v52 SourceLossCap DISABLED, " + "D) FIFO append SKIPPED, " + "E) set_loss_scale=loss_scale/num_mb, " + "G) v50 MTPGradFp32Coerce DISABLED, " + "H) v57 StochasticRoundBf16 DISABLED. " + "This restores Megatron-Core 0.16.0 native MTP " + "semantics (= slime), so " + "mtp_loss_scaling_factor=0.2 carries the same " + "meaning as in slime.", + _slime_align_on, + ) + except Exception as _e_sa: + self.logger.warning( + "[MTPSlimeAlign] banner log failed: %s", + _e_sa, + ) self.logger.info( "[MTPVersionBanner] tags=%s flags=%s", ",".join(_banner_tags), _banner_flags, @@ -1452,9 +1484,24 @@ def _hook(_p): # Gate: AREAL_MTP_V50_GRAD_FP32_COERCE (default "1"). try: import os as _os_v50g + # [MTPSlimeAlign] disable v50 fp32 coerce when slime-align is ON; + # slime/Megatron-Core native does not upcast MTP main_grad. + _v50_slime = ( + _os_v50g.environ.get('AREAL_MTP_SLIME_ALIGN', '1') == '1' + ) _v50_gfp32 = ( _os_v50g.environ.get('AREAL_MTP_V50_GRAD_FP32_COERCE', '1') == '1' + and not _v50_slime ) + if _v50_slime and not getattr(self, '_v58_v50_logged', False): + try: + self.logger.info( + '[MTPSlimeAlign] v50 MTPGradFp32Coerce DISABLED ' + '(slime/native does not upcast MTP main_grad).' + ) + self._v58_v50_logged = True + except Exception: + pass except Exception: _v50_gfp32 = True if ( @@ -2962,11 +3009,46 @@ def _patched_postprocess( # output_weight ...)` call below is LEFT UNTOUCHED so GRPO # gradient on lm_head / embedding is preserved exactly. # ---------------------------------------------------------------- - _mtp_output_weight_v53 = ( - output_weight.detach() - if output_weight is not None - else None - ) + # [MTPSlimeAlign] When slime-align is ON, pass + # the un-detached shared output_weight, exactly + # like Megatron-Core 0.16.0 native + # gpt_model.py:_postprocess and slime. This + # restores MTP CE -> shared lm_head/embedding + # gradient flow, which is essential for the + # main policy to track the draft distribution. + try: + import os as _os_v58_b + _v58_slime_b = ( + _os_v58_b.environ.get( + 'AREAL_MTP_SLIME_ALIGN', '1' + ) == '1' + ) + except Exception: + _v58_slime_b = True + if _v58_slime_b: + _mtp_output_weight_v53 = output_weight + if ( + not getattr( + _engine_ref, + '_v58_b_logged', False) + ): + try: + _logger.info( + '[MTPSlimeAlign] B) ' + 'output_weight passed ' + 'un-detached to MTP ' + 'output_layer (slime/' + 'native).' + ) + _engine_ref._v58_b_logged = True + except Exception: + pass + else: + _mtp_output_weight_v53 = ( + output_weight.detach() + if output_weight is not None + else None + ) if mtp_in_postprocess: hidden_states = self_model.mtp( @@ -3156,8 +3238,35 @@ def _patched_postprocess( # Disable: AREAL_MTP_V52_LOSS_CAP_RATIO=0 try: import os as _os_v52s - _v52_ratio = float(_os_v52s.environ.get( - 'AREAL_MTP_V52_LOSS_CAP_RATIO', '2.0')) + # [MTPSlimeAlign] force cap ratio + # to 0 when slime-align is ON; + # native Megatron-Core has no + # source-side loss cap. + if _os_v52s.environ.get( + 'AREAL_MTP_SLIME_ALIGN', '1' + ) == '1': + _v52_ratio = 0.0 + if not getattr( + _engine_ref, + '_v58_c_logged', False + ): + try: + _logger.info( + '[MTPSlime' + 'Align] C) ' + 'v52 Source' + 'LossCap ' + 'DISABLED ' + '(ratio=0, ' + 'slime/' + 'native).' + ) + _engine_ref._v58_c_logged = True + except Exception: + pass + else: + _v52_ratio = float(_os_v52s.environ.get( + 'AREAL_MTP_V52_LOSS_CAP_RATIO', '2.0')) except Exception: _v52_ratio = 2.0 if _v52_ratio > 0.0: @@ -3199,7 +3308,36 @@ def _patched_postprocess( '[MTPSourceLossCap-v52] failed: %s', _e_v52s, ) - _engine_ref._mtp_loss_for_backward.append(_mtp_loss_to_store) + # [MTPSlimeAlign] D) skip FIFO append + # when slime-align is ON; native MC has + # no scalar FIFO -- gradient injection + # is handled solely by + # MTPLossAutoScaler.apply below. + try: + import os as _os_v58_d + _v58_slime_d = ( + _os_v58_d.environ.get( + 'AREAL_MTP_SLIME_ALIGN', + '1') == '1' + ) + except Exception: + _v58_slime_d = True + if not _v58_slime_d: + _engine_ref._mtp_loss_for_backward.append(_mtp_loss_to_store) + elif not getattr( + _engine_ref, + '_v58_d_logged', False + ): + try: + _logger.info( + '[MTPSlimeAlign] D) ' + 'FIFO append SKIPPED ' + '(slime/native uses ' + 'autograd-only path).' + ) + _engine_ref._v58_d_logged = True + except Exception: + pass # --- BEGIN --- # Reproduce Megatron-native behaviour: @@ -3259,11 +3397,82 @@ def _patched_postprocess( # loss_scale via the outer # loss * loss_scale contract, # so only 1/num_mb is needed here. - _MTPLossAutoScaler_v17.set_loss_scale( - _torch_v17.tensor( - 1.0 / float(_num_mb_v17) + # [MTPSlimeAlign] E) match + # Megatron-Core schedules.py: + # loss_scale = grad_scale_func(1.0) + # set_loss_scale(loss_scale / num_microbatches) + # Falls back to 1/num_mb only + # when slime-align is OFF, to + # preserve legacy behaviour. + try: + import os as _os_v58_e + _v58_slime_e = ( + _os_v58_e.environ.get( + 'AREAL_MTP_SLIME_ALIGN', + '1') == '1' + ) + except Exception: + _v58_slime_e = True + if _v58_slime_e: + try: + _gsf_e = getattr( + self_model.config, + 'grad_scale_func', + None, + ) + _ls_e = ( + _gsf_e( + _torch_v17.ones( + 1, + device=hidden_states.device, + ) + ) + if _gsf_e is not None + else _torch_v17.ones( + 1, + device=hidden_states.device, + ) + ) + except Exception: + _ls_e = _torch_v17.ones( + 1, + device=hidden_states.device, + ) + _MTPLossAutoScaler_v17.set_loss_scale( + _ls_e / float(_num_mb_v17) + ) + if not getattr( + _engine_ref, + '_v58_e_logged', False + ): + try: + _logger.info( + '[MTPSlime' + 'Align] E) ' + 'set_loss_scale' + '=loss_scale/' + 'num_mb (= ' + 'Megatron-Core ' + 'schedules.py ' + ': %s / %d).', + float( + _ls_e.item() + if hasattr( + _ls_e, + 'item') + else _ls_e + ), + int(_num_mb_v17), + ) + _engine_ref._v58_e_logged = True + except Exception: + pass + else: + _MTPLossAutoScaler_v17.set_loss_scale( + _torch_v17.tensor( + 1.0 / float(_num_mb_v17) + ) ) - ) try: _d06_step = getattr( _engine_ref, @@ -3498,8 +3707,44 @@ def _mtp_backward_hook(grad, _lg=_logger, _gs=_gs_v5): _orig_postprocess, ) - # Path 3: patch _get_embeddings for embedding detach + # Path 3: patch _get_embeddings for embedding detach. + # [MTPSlimeAlign] A) Skip Path-3 detach when + # slime-align is ON; native Megatron-Core uses + # `make_viewless_tensor(..., keep_graph=True)` + # which preserves the gradient flow through the + # decoder_input/hidden_states into the main + # embedding & backbone -- this is precisely the + # mechanism that makes slime's + # `mtp_loss_scaling_factor=0.2` an effective + # main-policy regulariser. + try: + import os as _os_v58_a + _v58_slime_a = ( + _os_v58_a.environ.get( + 'AREAL_MTP_SLIME_ALIGN', '1') == '1' + ) + except Exception: + _v58_slime_a = True _mtp_block = getattr(_unwrapped, "mtp", None) + if _v58_slime_a: + if not getattr( + self, '_v58_a_logged', False + ): + try: + self.logger.info( + '[MTPSlimeAlign] A) Path-3 ' + '_get_embeddings detach SKIPPED. ' + 'Native Megatron-Core preserves ' + 'decoder_input/hidden_states via ' + 'make_viewless_tensor(keep_graph=' + 'True), letting MTP CE backward ' + 'flow into the main embedding & ' + 'backbone (slime semantics).' + ) + self._v58_a_logged = True + except Exception: + pass + _mtp_block = None # disables the patch loop if _mtp_block is not None and hasattr(_mtp_block, "layers"): for _layer in _mtp_block.layers: _orig_get_emb = _layer._get_embeddings @@ -6324,12 +6569,34 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: # count = numel * drift / ULP. Preserves # long-run unbiasedness E[SR(x)] = x. try: + # [MTPSlimeAlign] H) disable v57 SR when + # slime-align is ON; slime/native bf16 + # path uses RNE (round-nearest-even) only. + _v57_slime = ( + _os_v16.environ.get( + 'AREAL_MTP_SLIME_ALIGN', '1' + ) == '1' + ) _sr_on_v57 = ( _os_v16.environ.get( 'AREAL_MTP_V57_STOCHASTIC_ROUND', '1', ) == '1' + and not _v57_slime ) + if _v57_slime and not getattr( + self, '_v58_h_logged', False + ): + try: + self.logger.info( + '[MTPSlimeAlign] H) v57 ' + 'StochasticRoundBf16 ' + 'DISABLED (slime/native uses ' + 'RNE only).' + ) + self._v58_h_logged = True + except Exception: + pass except Exception: _sr_on_v57 = True _sr_applied_v57 = False From e331aa3a38d99a923efe240e37d0bdd47042433f Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 7 May 2026 23:17:19 +0800 Subject: [PATCH 130/140] fix(mcore): hf convert --- areal/engine/megatron_engine.py | 59 ++++++++++ areal/models/mcore/hf_load.py | 24 ++++ areal/models/mcore/mimo_mtp_hf_mapping.py | 134 ++++++++++++++++++++++ 3 files changed, 217 insertions(+) create mode 100644 areal/models/mcore/mimo_mtp_hf_mapping.py diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 3dec7432af..b074466742 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -4503,6 +4503,30 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: # --- MTP independent learning rate --- _mtp_lr_config_overrides = None _mtp_lr_scale = getattr(self.optimizer_config, 'mtp_lr_scale', 1.0) + # [P2-MTPShareParamGroup] When the MTP-only param group is activated + # via ParamKey(name=("*.mtp.*",)), Megatron 0.16 DistributedOptimizer + # only shards that (small) group across a subset of DP ranks, leaving + # the other ranks with ``param.main_param = None``. That breaks the + # weight-ship path because _collect_param(..) returns None on those + # ranks, so ``mtp_hf_tensors`` stays empty and sglang draft never + # gets updated. Default behaviour of this patch is to force MTP to + # share the main param group (mtp_lr_scale coerced to 1.0); opt-out + # via AREAL_MTP_SHARE_PARAM_GROUP=0. + _v59_share_pg = (os.environ.get( + "AREAL_MTP_SHARE_PARAM_GROUP", "1") == "1") + if ( + self.enable_mtp_training + and _v59_share_pg + and _mtp_lr_scale != 1.0 + ): + self.logger.warning( + "[MTPShareParamGroup-P2] overriding mtp_lr_scale=%.3f -> 1.0 " + "so MTP parameters share the main param group and every DP " + "rank holds a master-param shard. Set " + "AREAL_MTP_SHARE_PARAM_GROUP=0 to restore the split.", + _mtp_lr_scale, + ) + _mtp_lr_scale = 1.0 if self.enable_mtp_training and _mtp_lr_scale != 1.0: try: from megatron.core.optimizer.optimizer_config import ParamKey @@ -6135,6 +6159,41 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _sys_v24m.stdout.flush() except Exception: pass + # [P3-MTPShipFallback] When _collect_param returned None on + # this rank (typically because the MTP-only param group left + # no master shard here, or the fp32-master fetch raised), + # fall back to a plain bf16 all-gather of ``param`` so that + # the wire payload for the draft model is never dropped + # silently. Opt-out via AREAL_MTP_SHIP_FALLBACK=0. + if ( + _collect_mtp_for_draft + and _mtp_param is None + and os.environ.get( + "AREAL_MTP_SHIP_FALLBACK", "1") == "1" + ): + try: + _fb_param, _ = self._collect_param(name, param) + if _fb_param is not None: + _mtp_param = _fb_param + self.logger.warning( + "[MTPShipFallback-P3] rank=%d name=%s " + "fell back to bf16 all-gather (fp32 " + "master unavailable on this rank).", + dist.get_rank(), name, + ) + else: + self.logger.error( + "[MTPShipFallback-P3] rank=%d name=%s " + "bf16 all-gather also returned None; " + "MTP tensor will be skipped.", + dist.get_rank(), name, + ) + except Exception as _e_p3_fb: + self.logger.error( + "[MTPShipFallback-P3] rank=%d name=%s " + "fallback raised: %r", + dist.get_rank(), name, _e_p3_fb, + ) if _collect_mtp_for_draft and _mtp_param is not None: _mtp_model_name = self.hf_config.model_type _prev_count = len(mtp_hf_tensors) diff --git a/areal/models/mcore/hf_load.py b/areal/models/mcore/hf_load.py index e8b76fbaab..3048e873e2 100644 --- a/areal/models/mcore/hf_load.py +++ b/areal/models/mcore/hf_load.py @@ -468,6 +468,30 @@ def load_weights_from_hf_with_mbridge_fast( for k, v in local_to_global_map.items() if "_extra_state" not in k } + # [P1-MTPCkptLoad] mbridge MiMoBridge does not translate + # ``mtp.layers.{idx}.*`` MCore-global keys to their HF counterparts + # (``model.mtp_layers.{idx}.*``). Without this augmentation the MTP + # head boots from random weights (per-token CE ≈ log(vocab)) which + # crushes spec_accept_rate. Apply MiMo MTP mapping in-place when the + # base bridge produced an empty list. + try: + from areal.models.mcore.mimo_mtp_hf_mapping import ( + augment_local_to_hf_map_with_mtp, + ) + _mtp_patched = augment_local_to_hf_map_with_mtp( + local_to_global_map, local_to_hf_map, logger=logger, + ) + if _mtp_patched: + logger.info( + "[MTPCkptLoad-P1] applied MTP HF-name mapping for " + "%d local keys (model_index=%d)", + _mtp_patched, model_index, + ) + except Exception as _e_p1_mtp: + logger.warning( + "[MTPCkptLoad-P1] augment failed: %r (MTP weights may " + "still be missing)", _e_p1_mtp, + ) if manual_tie_word_embedding: for k, v in local_to_hf_map.items(): if "lm_head.weight" in v: diff --git a/areal/models/mcore/mimo_mtp_hf_mapping.py b/areal/models/mcore/mimo_mtp_hf_mapping.py new file mode 100644 index 0000000000..894b669fd2 --- /dev/null +++ b/areal/models/mcore/mimo_mtp_hf_mapping.py @@ -0,0 +1,134 @@ +"""MiMo MTP HF name-mapping helper. + +The upstream ``mbridge`` ``MiMoBridge`` (as of PR#1176 HEAD) does NOT translate +MTP-layer local keys such as ``mtp.layers.0.enorm.weight`` into their +HuggingFace counterparts under ``model.mtp_layers.0.*``. As a result, the +MiMo-7B-RL checkpoint's MTP weights are silently skipped during +``load_weights_from_hf_with_mbridge_fast``, leaving the MTP head at random +initialisation (per-token CE \u2248 log(vocab)). + +This module provides a pure-data mapping that mirrors +``areal.engine.megatron_utils.megatron._convert_mimo_mtp_param`` (the MCore \u2192 HF +direction already used by the weight-ship path), so that checkpoint loading +can be fixed without modifying the mbridge package itself. + +Usage is limited to ``areal.models.mcore.hf_load`` which calls +``augment_local_to_hf_map_with_mtp`` after the bridge has populated the base +mapping. +""" +from __future__ import annotations + +import re +from typing import Dict, List + +# Matches both ``mtp.layers.{idx}.{rest}`` and the ``decoder.mtp_layers.{idx}.`` +# variant that a few megatron-core revisions emit. +_MTP_GLOBAL_RE = re.compile( + r"^(?:decoder\.)?mtp(?:\.layers|_layers)\.(\d+)\.(.+)$" +) + +# MCore MTP suffix -> HF suffix under ``model.mtp_layers.{idx}.``. +# Multi-valued entries are merged by the existing qkv / gate-up handling in +# ``hf_load._convert_hf_weights_to_mcore``. +_MTP_SUFFIX_MAP: Dict[str, object] = { + # MTP-specific layer norms and projections + "enorm.weight": "token_layernorm.weight", + "hnorm.weight": "hidden_layernorm.weight", + "eh_proj.weight": "input_proj.weight", + "final_layernorm.weight": "final_layernorm.weight", + + # transformer_layer.* (reused Qwen2 decoder block) + "transformer_layer.input_layernorm.weight": + "input_layernorm.weight", + "transformer_layer.self_attention.linear_qkv.layer_norm_weight": + "input_layernorm.weight", + "transformer_layer.self_attention.linear_qkv.weight": [ + "self_attn.q_proj.weight", + "self_attn.k_proj.weight", + "self_attn.v_proj.weight", + ], + "transformer_layer.self_attention.linear_qkv.bias": [ + "self_attn.q_proj.bias", + "self_attn.k_proj.bias", + "self_attn.v_proj.bias", + ], + "transformer_layer.self_attention.linear_proj.weight": + "self_attn.o_proj.weight", + + "transformer_layer.pre_mlp_layernorm.weight": + "post_attention_layernorm.weight", + "transformer_layer.mlp.linear_fc1.layer_norm_weight": + "post_attention_layernorm.weight", + "transformer_layer.mlp.linear_fc1.weight": [ + "mlp.gate_proj.weight", + "mlp.up_proj.weight", + ], + "transformer_layer.mlp.linear_fc2.weight": + "mlp.down_proj.weight", +} + + +def mtp_mcore_name_to_hf_names(global_name: str) -> List[str]: + """Return the HF keys matching one MCore MTP-global name. + + Returns an empty list if ``global_name`` does not look like an MTP entry + or has no explicit mapping rule (e.g. ``_extra_state`` tails, unknown + subcomponents - these are logged by the caller). + """ + m = _MTP_GLOBAL_RE.match(global_name) + if m is None: + return [] + idx, rest = m.group(1), m.group(2) + if rest.endswith("_extra_state"): + return [] + rule = _MTP_SUFFIX_MAP.get(rest) + if rule is None: + return [] + prefix = f"model.mtp_layers.{idx}." + if isinstance(rule, str): + return [prefix + rule] + return [prefix + s for s in rule] + + +def augment_local_to_hf_map_with_mtp( + local_to_global_map: Dict[str, str], + local_to_hf_map: Dict[str, List[str]], + logger=None, +) -> int: + """Inject MTP HF-name mappings into ``local_to_hf_map`` in-place. + + Iterates the local keys whose global name matches the MTP pattern. If + the existing mapping is empty (i.e. the base bridge did not know how to + translate it), populate it from :func:`mtp_mcore_name_to_hf_names`. + + Returns the number of local keys patched. A single [MTPCkptLoad-P1] + summary is emitted via ``logger`` for verification. + """ + patched = 0 + preview: List[str] = [] + for local_name, global_name in local_to_global_map.items(): + if "_extra_state" in local_name: + continue + m = _MTP_GLOBAL_RE.match(global_name) + if m is None: + continue + cur = local_to_hf_map.get(local_name) or [] + if cur: + continue + hf_names = mtp_mcore_name_to_hf_names(global_name) + if not hf_names: + continue + local_to_hf_map[local_name] = hf_names + patched += 1 + if len(preview) < 3: + preview.append(f"{local_name}->{hf_names}") + if logger is not None: + try: + logger.info( + "[MTPCkptLoad-P1] augment_local_to_hf_map_with_mtp " + "patched=%d examples=%s", + patched, preview, + ) + except Exception: + pass + return patched From 37e3ff757e5d41269dfd9a311b65d04a4e203e5b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 7 May 2026 23:43:17 +0800 Subject: [PATCH 131/140] fix: d1 --- examples/math/gsm8k_grpo_megatron_mimo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/gsm8k_grpo_megatron_mimo.yaml b/examples/math/gsm8k_grpo_megatron_mimo.yaml index 7a3030deaf..433b5fb0b8 100644 --- a/examples/math/gsm8k_grpo_megatron_mimo.yaml +++ b/examples/math/gsm8k_grpo_megatron_mimo.yaml @@ -40,7 +40,7 @@ gconfig: temperature: 1.0 actor: - backend: "megatron:d2p1t2" + backend: "megatron:d1p1t2" experiment_name: ${experiment_name} trial_name: ${trial_name} path: /workspace/models/MiMo-7B-RL From 4b65adf52dfff79c5b22b8485395b582c1efa7c8 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 00:14:07 +0800 Subject: [PATCH 132/140] fix(mcore/mimo_mtp_hf_mapping): fix convert --- areal/models/mcore/mimo_mtp_hf_mapping.py | 77 ++++++++++++++++------- 1 file changed, 56 insertions(+), 21 deletions(-) diff --git a/areal/models/mcore/mimo_mtp_hf_mapping.py b/areal/models/mcore/mimo_mtp_hf_mapping.py index 894b669fd2..b12ee305f6 100644 --- a/areal/models/mcore/mimo_mtp_hf_mapping.py +++ b/areal/models/mcore/mimo_mtp_hf_mapping.py @@ -2,15 +2,19 @@ The upstream ``mbridge`` ``MiMoBridge`` (as of PR#1176 HEAD) does NOT translate MTP-layer local keys such as ``mtp.layers.0.enorm.weight`` into their -HuggingFace counterparts under ``model.mtp_layers.0.*``. As a result, the -MiMo-7B-RL checkpoint's MTP weights are silently skipped during -``load_weights_from_hf_with_mbridge_fast``, leaving the MTP head at random -initialisation (per-token CE \u2248 log(vocab)). +HuggingFace counterparts under ``model.mtp_layers.0.*``. Worse, in practice +it falls back to the parent Qwen2 rules and returns a **non-empty but wrong** +list (e.g. ``model.layers.0.*``), which downstream silently ignores when the +index lookup fails. That left every non-``eh_proj`` MTP tensor at random +initialisation (per-token CE \u2248 log(vocab) \u2248 11.24) and dragged the +spec-decoding accept-rate below 0.30 after the first weight-ship. This module provides a pure-data mapping that mirrors ``areal.engine.megatron_utils.megatron._convert_mimo_mtp_param`` (the MCore \u2192 HF -direction already used by the weight-ship path), so that checkpoint loading -can be fixed without modifying the mbridge package itself. +direction already used by the weight-ship path), and - in v60 - unconditionally +OVERWRITES whatever the base bridge produced for MTP keys, aligning with +``slime/slime_plugins/mbridge/mimo.py::_weight_name_mapping_mcore_to_hf`` which +also hard-routes all ``mtp.*`` names through an MTP-specific converter. Usage is limited to ``areal.models.mcore.hf_load`` which calls ``augment_local_to_hf_map_with_mtp`` after the bridge has populated the base @@ -18,6 +22,7 @@ """ from __future__ import annotations +import os import re from typing import Dict, List @@ -97,37 +102,67 @@ def augment_local_to_hf_map_with_mtp( ) -> int: """Inject MTP HF-name mappings into ``local_to_hf_map`` in-place. - Iterates the local keys whose global name matches the MTP pattern. If - the existing mapping is empty (i.e. the base bridge did not know how to - translate it), populate it from :func:`mtp_mcore_name_to_hf_names`. + v60 behaviour (slime-aligned, see ``slime_plugins/mbridge/mimo.py``): + any local key whose global name matches the MTP pattern is **authoritatively + overwritten** with the MTP-specific HF names produced by this module. The + upstream bridge's Qwen2 default rules are discarded, because they point + at ``model.layers.{idx}.*`` keys that do not exist in the MiMo checkpoint + and were silently ignored by the downstream loader - the very reason the + MTP head kept booting at random initialisation. - Returns the number of local keys patched. A single [MTPCkptLoad-P1] - summary is emitted via ``logger`` for verification. + Opt-out: ``AREAL_MTP_P1_OVERWRITE=0`` reverts to v59 "only when empty" + behaviour for A/B testing. + + Returns the number of local keys patched. A single ``[MTPCkptLoad-P1]`` + summary with ``overwritten_nonempty`` / ``filled_empty`` / ``skipped_no_rule`` + breakdown is emitted via ``logger`` for verification. """ + overwrite = os.environ.get("AREAL_MTP_P1_OVERWRITE", "1") == "1" patched = 0 - preview: List[str] = [] + filled_empty = 0 + overwritten_nonempty = 0 + skipped_no_rule = 0 + preview_filled: List[str] = [] + preview_overwritten: List[str] = [] for local_name, global_name in local_to_global_map.items(): if "_extra_state" in local_name: continue m = _MTP_GLOBAL_RE.match(global_name) if m is None: continue - cur = local_to_hf_map.get(local_name) or [] - if cur: - continue hf_names = mtp_mcore_name_to_hf_names(global_name) if not hf_names: + skipped_no_rule += 1 continue - local_to_hf_map[local_name] = hf_names - patched += 1 - if len(preview) < 3: - preview.append(f"{local_name}->{hf_names}") + cur = local_to_hf_map.get(local_name) or [] + if cur: + if not overwrite: + # v59 compatibility mode + continue + # v60: authoritative overwrite (slime-aligned) + local_to_hf_map[local_name] = hf_names + overwritten_nonempty += 1 + patched += 1 + if len(preview_overwritten) < 3: + preview_overwritten.append( + f"{local_name}: {cur}->{hf_names}" + ) + else: + local_to_hf_map[local_name] = hf_names + filled_empty += 1 + patched += 1 + if len(preview_filled) < 3: + preview_filled.append(f"{local_name}->{hf_names}") if logger is not None: try: logger.info( "[MTPCkptLoad-P1] augment_local_to_hf_map_with_mtp " - "patched=%d examples=%s", - patched, preview, + "patched=%d (overwritten_nonempty=%d, filled_empty=%d, " + "skipped_no_rule=%d) overwrite_mode=%s " + "preview_overwritten=%s preview_filled=%s", + patched, overwritten_nonempty, filled_empty, + skipped_no_rule, overwrite, + preview_overwritten, preview_filled, ) except Exception: pass From 2e8d18897a70e86c93fb10f9552e592b9a5fa4bf Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 11:56:40 +0800 Subject: [PATCH 133/140] feat: mtp weight log --- areal/engine/megatron_engine.py | 231 ++++++++++++++++++++++++++++++++ areal/models/mcore/hf_load.py | 51 +++++++ 2 files changed, 282 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index b074466742..ddac1d0c94 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -3105,6 +3105,88 @@ def _patched_postprocess( mtp_layer_number, _gs_fwd, _mtp_hs.requires_grad, list(_mtp_hs.shape), _mtp_hs_gfn, hidden_states.requires_grad) + # [MTPFwdWeightAudit-v61] log live MTP + # weight statistics every 100 steps to + # detect silent corruption between load + # time (MTPLoad/MTPPreScan) and forward. + try: + _v61w_gs = getattr( + _engine_ref, + '_global_step', 0, + ) + if (_mtp_diag_mb_counter[0] == 0 + and (_v61w_gs <= 3 + or _v61w_gs % 100 == 0)): + _v61w_mtp = getattr( + self_model, 'mtp', None, + ) + if _v61w_mtp is not None: + for _v61w_pn in ( + 'enorm.weight', + 'hnorm.weight', + 'eh_proj.weight', + ): + try: + _v61w_p = ( + _v61w_mtp + .layers[0] + ) + for _v61w_part in _v61w_pn.split('.'): + _v61w_p = getattr( + _v61w_p, + _v61w_part, + ) + _v61w_pf = _v61w_p.detach().float() + _v61w_am = float(_v61w_pf.abs().mean().item()) + _v61w_ax = float(_v61w_pf.abs().max().item()) + _v61w_l2 = float(_v61w_pf.norm().item()) + _v61w_first8 = [ + float(x) for x in + _v61w_pf.reshape(-1)[:8].tolist() + ] + _logger.info( + '[MTPFwdWeightAudit-v61] ' + 'step=%d mtp.layers.0.%s ' + 'dtype=%s shape=%s ' + 'abs_mean=%.6e abs_max=%.6e ' + 'l2=%.6e first8=%s', + _v61w_gs, _v61w_pn, + str(_v61w_p.dtype), + str(tuple(_v61w_p.shape)), + _v61w_am, _v61w_ax, + _v61w_l2, + str(_v61w_first8), + ) + except Exception: + continue + # also probe output_weight + try: + _v61w_ow = _mtp_output_weight_v53 + if _v61w_ow is not None: + _v61w_owf = _v61w_ow.detach().float() + _logger.info( + '[MTPFwdWeightAudit-v61] ' + 'step=%d output_weight ' + 'dtype=%s shape=%s ' + 'abs_mean=%.6e abs_max=%.6e ' + 'l2=%.6e', + _v61w_gs, + str(_v61w_ow.dtype), + str(tuple(_v61w_ow.shape)), + float(_v61w_owf.abs().mean().item()), + float(_v61w_owf.abs().max().item()), + float(_v61w_owf.norm().item()), + ) + except Exception: + pass + except Exception as _e_v61w: + try: + _logger.info( + '[MTPFwdWeightAudit-v61] ' + 'failure: %r', _e_v61w, + ) + except Exception: + pass mtp_logits, _ = self_model.output_layer( _mtp_hs, # [MTPSharedWeightIsolate-v53] detached weight @@ -6194,6 +6276,45 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: "fallback raised: %r", dist.get_rank(), name, _e_p3_fb, ) + # [MTPShipEnumTrace-v61] log per-MTP-param ship enumeration + # ENTER. Captures whether MTP path will collect this tensor, + # the param's bf16 statistics, and shape — independent of + # later HF-name expansion. + try: + if _collect_mtp_for_draft and ('mtp' in name): + _v61_pa = _mtp_param if _mtp_param is not None else None + if _v61_pa is not None: + _v61_pf = _v61_pa.detach().float() + _v61_am = float(_v61_pf.abs().mean().item()) + _v61_ax = float(_v61_pf.abs().max().item()) + _v61_l2 = float(_v61_pf.norm().item()) + _v61_n = int(_v61_pa.numel()) + _v61_dt = str(_v61_pa.dtype) + _v61_sh = tuple(_v61_pa.shape) + else: + _v61_am = _v61_ax = _v61_l2 = -1.0 + _v61_n = 0 + _v61_dt = 'None' + _v61_sh = () + self.logger.info( + '[MTPShipEnumTrace-v61] stage=ENTER rank=%d ' + 'name=%s collect=%s mtp_param_is_none=%s ' + 'numel=%d dtype=%s shape=%s ' + 'abs_mean=%.6e abs_max=%.6e l2=%.6e', + int(dist.get_rank()), name, + str(_collect_mtp_for_draft), + str(_mtp_param is None), + _v61_n, _v61_dt, str(_v61_sh), + _v61_am, _v61_ax, _v61_l2, + ) + except Exception as _e_v61_a: + try: + self.logger.info( + '[MTPShipEnumTrace-v61] ENTER failure: %r', + _e_v61_a, + ) + except Exception: + pass if _collect_mtp_for_draft and _mtp_param is not None: _mtp_model_name = self.hf_config.model_type _prev_count = len(mtp_hf_tensors) @@ -6339,6 +6460,47 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: fp8_direct_convert=self.fp8_direct_convert, ) ) + # [MTPShipEnumTrace-v61] EXIT — log expanded HF names + # and per-tensor bytes added by convert_to_hf for this + # one mcore param. + try: + _v61_added = mtp_hf_tensors[_prev_count:] + for _v61_i, (_v61_hn, _v61_ht) in enumerate(_v61_added): + try: + _v61_hf = _v61_ht.detach().float() + _v61_ham = float(_v61_hf.abs().mean().item()) + _v61_hax = float(_v61_hf.abs().max().item()) + _v61_hl2 = float(_v61_hf.norm().item()) + _v61_hfirst = [ + float(x) for x in + _v61_hf.reshape(-1)[:8].tolist() + ] + except Exception: + _v61_ham = _v61_hax = _v61_hl2 = -1.0 + _v61_hfirst = [] + self.logger.info( + '[MTPShipEnumTrace-v61] stage=EXIT rank=%d ' + 'mcore=%s hf_idx=%d hf_name=%s ' + 'hf_dtype=%s hf_shape=%s hf_numel=%d ' + 'hf_bytes=%d abs_mean=%.6e abs_max=%.6e ' + 'l2=%.6e first8=%s', + int(dist.get_rank()), name, + _v61_i, _v61_hn, + str(_v61_ht.dtype), + str(tuple(_v61_ht.shape)), + int(_v61_ht.numel()), + int(_v61_ht.numel() * _v61_ht.element_size()), + _v61_ham, _v61_hax, _v61_hl2, + str(_v61_hfirst), + ) + except Exception as _e_v61_b: + try: + self.logger.info( + '[MTPShipEnumTrace-v61] EXIT failure: %r', + _e_v61_b, + ) + except Exception: + pass # [MTPBf16UpcastBroadcast-v24] Upcast bf16->fp32 # before serialize so sub-ULP deltas are not # rounded on the wire (default 1). @@ -7559,6 +7721,75 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: f"MTP draft model weights will NOT be updated!" ) + # [MTPShipFinalSummary-v61] one-shot definitive ship summary right + # after the MTP loop completes, BEFORE serialize/send. Unlike the + # per-bucket-flush v56 ship_summary (which can fire 13+ times with + # n_mtp_shipped=0 because the bucket flush happens DURING the MTP + # collection loop), this one fires exactly once per ship and shows + # the actual MTP wire payload list contents. + try: + if _collect_mtp_for_draft: + _v61f_ver = getattr(meta, 'version', 'NA') + _v61f_n = len(mtp_hf_tensors) + _v61f_total_bytes = sum( + int(t.numel() * t.element_size()) + for _, t in mtp_hf_tensors + ) + _v61f_first = mtp_hf_tensors[0][0] if _v61f_n > 0 else None + _v61f_names = [n for n, _ in mtp_hf_tensors] + self.logger.info( + '[MTPShipFinalSummary-v61] rank=%d version=%s ' + 'n_mtp_shipped=%d total_bytes=%d first=%s ' + 'names=%s', + int(dist.get_rank()), str(_v61f_ver), _v61f_n, + _v61f_total_bytes, str(_v61f_first), str(_v61f_names), + ) + # Cross-version delta on a sentinel HF tensor. + if _v61f_n > 0: + if not hasattr(self, '_v61_prev_ship_first8'): + self._v61_prev_ship_first8 = {} + for _v61f_n2, _v61f_t in mtp_hf_tensors: + try: + _v61f_f = _v61f_t.detach().float().reshape(-1) + _v61f_first8 = [ + float(x) for x in _v61f_f[:8].tolist() + ] + _v61f_l2 = float(_v61f_f.norm().item()) + _v61f_prev = self._v61_prev_ship_first8.get( + _v61f_n2, + ) + self._v61_prev_ship_first8[_v61f_n2] = ( + _v61f_first8, _v61f_l2, + ) + if _v61f_prev is not None: + _v61f_pf, _v61f_pl2 = _v61f_prev + _v61f_d8 = [ + (a - b) for a, b in zip( + _v61f_first8, _v61f_pf, + ) + ] + _v61f_dl2 = abs(_v61f_l2 - _v61f_pl2) + else: + _v61f_d8 = [] + _v61f_dl2 = -1.0 + self.logger.info( + '[MTPShipDelta-v61] rank=%d version=%s ' + 'name=%s l2=%.6e d_l2=%.6e first8=%s ' + 'd_first8=%s', + int(dist.get_rank()), str(_v61f_ver), + _v61f_n2, _v61f_l2, _v61f_dl2, + str(_v61f_first8), + str(_v61f_d8), + ) + except Exception: + continue + except Exception as _e_v61f: + try: + self.logger.info( + '[MTPShipFinalSummary-v61] failure: %r', _e_v61f, + ) + except Exception: + pass if _collect_mtp_for_draft and mtp_hf_tensors and dist.get_rank() == 0: try: tp_size = ( diff --git a/areal/models/mcore/hf_load.py b/areal/models/mcore/hf_load.py index 3048e873e2..a58adbb581 100644 --- a/areal/models/mcore/hf_load.py +++ b/areal/models/mcore/hf_load.py @@ -217,6 +217,17 @@ def _weight_to_mcore_tp( mcore_param_shape, hf_weights_safe_slice, tp_rank, tp_size ) if not isinstance(res, FP8BlockwiseTensorHelper): + # [MTPLoadHashAudit-v61] capture pre-swap stats + try: + _v61l_pre = res.detach().float() + _v61l_pre_first8 = [ + float(x) for x in _v61l_pre.reshape(-1)[:8].tolist() + ] + _v61l_pre_am = float(_v61l_pre.abs().mean().item()) + _v61l_pre_ax = float(_v61l_pre.abs().max().item()) + except Exception: + _v61l_pre_first8 = [] + _v61l_pre_am = _v61l_pre_ax = -1.0 first_half, second_half = res.chunk(2, dim=1) res = torch.cat([second_half, first_half], dim=1) logger.info( @@ -224,6 +235,24 @@ def _weight_to_mcore_tp( f"{mcore_weights_name}, shape={tuple(res.shape)}, " f"tp_rank={tp_rank}, tp_size={tp_size}" ) + try: + _v61l_post = res.detach().float() + _v61l_post_first8 = [ + float(x) for x in _v61l_post.reshape(-1)[:8].tolist() + ] + logger.info( + "[MTPLoadHashAudit-v61] eh_proj swap pre_first8=%s " + "post_first8=%s pre_abs_mean=%.6e pre_abs_max=%.6e " + "post_abs_mean=%.6e post_abs_max=%.6e tp_rank=%d " + "tp_size=%d", + str(_v61l_pre_first8), str(_v61l_post_first8), + _v61l_pre_am, _v61l_pre_ax, + float(_v61l_post.abs().mean().item()), + float(_v61l_post.abs().max().item()), + tp_rank, tp_size, + ) + except Exception: + pass else: res = _slice_generic_weight( mcore_param_shape, hf_weights_safe_slice, tp_rank, tp_size @@ -231,6 +260,28 @@ def _weight_to_mcore_tp( if dtype is not None and not isinstance(res, FP8BlockwiseTensorHelper): res = res.to(dtype) + # [MTPLoadHashAudit-v61] for any mtp-bearing mcore name, log the final + # post-conversion stats so we can correlate with MTPPreScan / MTPSrcHash. + try: + if isinstance(res, torch.Tensor) and ('mtp' in mcore_weights_name): + _v61la_f = res.detach().float() + _v61la_first8 = [ + float(x) for x in _v61la_f.reshape(-1)[:8].tolist() + ] + logger.info( + "[MTPLoadHashAudit-v61] post_convert mcore=%s dtype=%s " + "shape=%s abs_mean=%.6e abs_max=%.6e l2=%.6e first8=%s " + "tp_rank=%d tp_size=%d", + mcore_weights_name, str(res.dtype), + str(tuple(res.shape)), + float(_v61la_f.abs().mean().item()), + float(_v61la_f.abs().max().item()), + float(_v61la_f.norm().item()), + str(_v61la_first8), + tp_rank, tp_size, + ) + except Exception: + pass return res From 23f863ea3155df466a9445318835331b4023f0e5 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 14:35:39 +0800 Subject: [PATCH 134/140] feat: log --- areal/engine/megatron_engine.py | 262 ++++++++++++++++++++++++++++++++ 1 file changed, 262 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ddac1d0c94..5f7aba34da 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -3065,6 +3065,128 @@ def _patched_postprocess( embedding=self_model.embedding, **(extra_block_kwargs or {}), ) + # [MTPModelStructAudit-v62] one-shot: + # confirm self_model.mtp IS wired + # and list its layer-0 sub-module + # parameter names + shapes so we + # can cross-check against mcore + # ship list on the NEXT run. + try: + if not getattr(_engine_ref, + '_v62_struct_logged', False): + _v62_mtp_mod = getattr( + self_model, 'mtp', None) + _v62_dec_has_mtp = hasattr( + getattr(self_model, + 'decoder', object()), + 'mtp_layers') + _v62_mtp_proc = getattr( + self_model, 'mtp_process', None) + _v62_names = [] + if _v62_mtp_mod is not None: + try: + _v62_L0 = _v62_mtp_mod.layers[0] + for _v62_pn, _v62_pp in ( + _v62_L0.named_parameters()): + _v62_names.append( + (_v62_pn, + tuple(_v62_pp.shape), + str(_v62_pp.dtype))) + except Exception: + pass + _logger.info( + '[MTPModelStructAudit-v62] ' + 'self_model.mtp=%s ' + 'mtp_process=%s ' + 'mtp_in_postprocess_arg=%s ' + 'decoder.mtp_layers?=%s ' + 'layer0_params=%s', + type(_v62_mtp_mod).__name__ + if _v62_mtp_mod is not None else 'None', + _v62_mtp_proc, + bool(mtp_in_postprocess), + _v62_dec_has_mtp, + _v62_names[:32], + ) + _engine_ref._v62_struct_logged = True + except Exception as _e_v62_s: + try: + _logger.info( + '[MTPModelStructAudit-v62] ' + 'failure: %r', _e_v62_s) + except Exception: + pass + + # [MTPInputIdsAudit-v62] log input_ids / + # labels / hidden_states shape BEFORE the + # chunk so we can verify whether the + # decoder really produced 1+mtp_num_layers + # concatenated seq_len chunks (the signal + # that MTP block ran and the shift-by-1 + # label alignment is correct). + try: + _v62_gs = getattr( + _engine_ref, '_global_step', 0) + if (_mtp_diag_mb_counter[0] == 0 + and (_v62_gs <= 3 + or _v62_gs % 100 == 0)): + try: + _v62_iid_sh = ( + tuple(input_ids.shape) + if input_ids is not None + else None) + _v62_iid_f8 = ( + [int(x) for x in + input_ids.reshape(-1)[:8].tolist()] + if input_ids is not None else []) + except Exception: + _v62_iid_sh = None + _v62_iid_f8 = [] + try: + _v62_lb_sh = ( + tuple(labels.shape) + if labels is not None + else None) + _v62_lb_f8 = ( + [int(x) for x in + labels.reshape(-1)[:8].tolist()] + if labels is not None else []) + except Exception: + _v62_lb_sh = None + _v62_lb_f8 = [] + try: + _v62_hs_sh = tuple( + hidden_states.shape) + _v62_hs_f8 = [ + float(x) for x in + hidden_states.detach() + .float().reshape(-1)[:8] + .tolist()] + except Exception: + _v62_hs_sh = None + _v62_hs_f8 = [] + _logger.info( + '[MTPInputIdsAudit-v62] ' + 'step=%d mtp_num_layers=%s ' + 'input_ids.shape=%s ' + 'input_ids.first8=%s ' + 'labels.shape=%s ' + 'labels.first8=%s ' + 'hidden_states.shape=%s ' + 'hidden_states.first8=%s', + _v62_gs, + self_model.config.mtp_num_layers, + _v62_iid_sh, _v62_iid_f8, + _v62_lb_sh, _v62_lb_f8, + _v62_hs_sh, _v62_hs_f8, + ) + except Exception as _e_v62_i: + try: + _logger.info( + '[MTPInputIdsAudit-v62] ' + 'failure: %r', _e_v62_i) + except Exception: + pass if not self_model.post_process: return hidden_states @@ -3076,6 +3198,52 @@ def _patched_postprocess( 1 + self_model.config.mtp_num_layers, dim=0, ) + # [MTPHsChunkAudit-v62] per-chunk + # stats: if the MTP block really ran + # inside `self.mtp(...)`, chunks + # should be DISTINCT; if they are + # identical the MTP block was NOT + # exercised and the decline comes + # from the main backbone only. + try: + _v62_cgs = getattr( + _engine_ref, '_global_step', 0) + if (_mtp_diag_mb_counter[0] == 0 + and (_v62_cgs <= 3 + or _v62_cgs % 100 == 0)): + for _v62_ci, _v62_ch in enumerate( + hidden_states_list): + try: + _v62_chf = _v62_ch.detach().float() + _v62_l2 = float(_v62_chf.norm().item()) + _v62_am = float( + _v62_chf.abs().mean().item()) + _v62_ax = float( + _v62_chf.abs().max().item()) + _v62_f8 = [ + float(x) for x in + _v62_chf.reshape(-1)[:8].tolist()] + _logger.info( + '[MTPHsChunkAudit-v62] ' + 'step=%d chunk=%d/%d ' + 'shape=%s abs_mean=%.6e ' + 'abs_max=%.6e l2=%.6e ' + 'first8=%s', + _v62_cgs, _v62_ci, + len(hidden_states_list), + tuple(_v62_ch.shape), + _v62_am, _v62_ax, _v62_l2, + _v62_f8, + ) + except Exception: + continue + except Exception as _e_v62_c: + try: + _logger.info( + '[MTPHsChunkAudit-v62] ' + 'failure: %r', _e_v62_c) + except Exception: + pass hidden_states = hidden_states_list[0] if loss_mask is None: loss_mask = torch.ones_like(mtp_labels) @@ -3220,6 +3388,52 @@ def _patched_postprocess( mtp_loss = self_model.compute_language_model_loss( mtp_labels, mtp_logits ) + # [MTPLossPerLayerAudit-v62] + # break down aggregated + # mtp_loss per mtp layer so + # we can see whether layer-0 + # is learning (CE decreasing) + # even while spec_accept_rate + # declines. + try: + _v62_lgs = getattr( + _engine_ref, '_global_step', 0) + if (_mtp_diag_mb_counter[0] == 0 + and (_v62_lgs <= 3 + or _v62_lgs % 100 == 0)): + try: + _v62_ml_sum = float( + mtp_loss.detach() + .float().sum().item()) + except Exception: + _v62_ml_sum = float('nan') + try: + _v62_nt = int( + num_tokens.detach() + .sum().item()) + except Exception: + _v62_nt = -1 + _v62_mean = ( + _v62_ml_sum / _v62_nt + if _v62_nt > 0 else float('nan')) + _logger.info( + '[MTPLossPerLayerAudit-v62] ' + 'step=%d mtp_layer=%d ' + 'loss_sum=%.4f ' + 'num_tokens=%d ' + 'loss_mean=%.4f', + _v62_lgs, + mtp_layer_number, + _v62_ml_sum, _v62_nt, + _v62_mean, + ) + except Exception as _e_v62_l: + try: + _logger.info( + '[MTPLossPerLayerAudit-v62] ' + 'failure: %r', _e_v62_l) + except Exception: + pass mtp_loss = loss_mask * mtp_loss try: _d05_step = getattr( @@ -7790,6 +8004,54 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) except Exception: pass + # [MTPShipHashAudit-v62] rank-0 full list dump with hash + # so that next round we can cross-check exactly which + # HF-named bytes were shipped versus what the draft + # engine received / applied. This is independent of + # the existing v54/v56/v61 summaries; focuses only on + # deterministic content-hash identity of each tensor. + if (_collect_mtp_for_draft and mtp_hf_tensors + and dist.get_rank() == 0): + try: + import hashlib as _v62_hashlib + for _v62_hn, _v62_ht in mtp_hf_tensors: + try: + _v62_cpu = ( + _v62_ht.detach().contiguous() + .cpu().view(torch.uint8)) + _v62_nb = _v62_cpu.numel() + _v62_h = _v62_hashlib.sha256( + _v62_cpu.numpy().tobytes()).hexdigest()[:16] + _v62_f8 = [ + float(x) for x in + _v62_ht.detach().float() + .reshape(-1)[:8].tolist()] + self.logger.info( + '[MTPShipHashAudit-v62] version=%s ' + 'hf_name=%s dtype=%s shape=%s ' + 'bytes=%d sha256_16=%s first8=%s', + getattr(meta, 'version', None), + _v62_hn, + str(_v62_ht.dtype), + tuple(_v62_ht.shape), + _v62_nb, _v62_h, _v62_f8, + ) + except Exception as _e_v62_t: + try: + self.logger.info( + '[MTPShipHashAudit-v62] ' + 'tensor %s failure: %r', + _v62_hn, _e_v62_t) + except Exception: + pass + except Exception as _e_v62_out: + try: + self.logger.info( + '[MTPShipHashAudit-v62] outer failure: %r', + _e_v62_out) + except Exception: + pass + if _collect_mtp_for_draft and mtp_hf_tensors and dist.get_rank() == 0: try: tp_size = ( From cd6fa365793faa65d89ada28359b20ac1b8a9cdc Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 15:49:14 +0800 Subject: [PATCH 135/140] feat(megatron): log1 --- areal/engine/megatron_engine.py | 240 ++++++++++++++++++++++++ areal/engine/megatron_utils/megatron.py | 65 +++++++ 2 files changed, 305 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 5f7aba34da..3f3ea753d0 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -6664,6 +6664,126 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) except Exception: pass + # [MTPShipPostAGAudit-v63] Right BEFORE convert_to_hf, + # log the post-all_gather _mtp_param tensor (full + # gathered shape, sha256_16, first/last 8). This is + # the EXACT mcore-side payload that goes into the + # HF mapping. Comparing this hash across versions + # tells us whether ship-time TP all_gather is + # producing identical-byte tensors per version + # (would explain stalled draft despite training). + try: + import hashlib as _v63_pag_hash + _v63_pag_t = _mtp_param.detach().contiguous() + _v63_pag_bytes = ( + _v63_pag_t.float().cpu().numpy().tobytes() + ) + _v63_pag_h = _v63_pag_hash.sha256( + _v63_pag_bytes).hexdigest()[:16] + _v63_pag_first = [ + float(x) for x in + _v63_pag_t.reshape(-1)[:8].float() + .cpu().tolist() + ] + _v63_pag_last = [ + float(x) for x in + _v63_pag_t.reshape(-1)[-8:].float() + .cpu().tolist() + ] + try: + _v63_pag_ver = int(self.get_version()) + except Exception: + _v63_pag_ver = -1 + self.logger.info( + "[MTPShipPostAGAudit-v63] version=%d " + "name=%s shape=%s dtype=%s " + "sha256_16=%s first8=%s last8=%s " + "abs_mean=%.6e abs_max=%.6e l2=%.6e", + _v63_pag_ver, name, + tuple(_v63_pag_t.shape), + str(_v63_pag_t.dtype), + _v63_pag_h, + str(_v63_pag_first), str(_v63_pag_last), + float(_v63_pag_t.float().abs().mean().item()), + float(_v63_pag_t.float().abs().max().item()), + float(_v63_pag_t.float().norm().item()), + ) + except Exception as _e_v63_pag: + try: + self.logger.info( + "[MTPShipPostAGAudit-v63] failure: %r", + _e_v63_pag, + ) + except Exception: + pass + # [MTPMainParamCmpAudit-v63] Compare bf16 model + # param vs fp32 main_param at ship time. If they + # diverge by more than bf16 ULP, stochastic + # rounding desync between training and ship is + # the root cause of post-ship draft regression. + try: + _v63_mp_param_obj = param # original module param + _v63_mp = getattr( + _v63_mp_param_obj, 'main_param', None) + if _v63_mp is not None: + import torch as _v63_torch_mp + _v63_mp_fp32 = _v63_mp.detach().float() + _v63_bf = _v63_mp_param_obj.detach().float() + if _v63_mp_fp32.shape == _v63_bf.shape: + _v63_d = (_v63_mp_fp32 - _v63_bf).abs() + _v63_d_max = float(_v63_d.max().item()) + _v63_d_mean = float(_v63_d.mean().item()) + _v63_amax = float( + _v63_mp_fp32.abs().max().item()) + _v63_ulp = -1.0 + if _v63_amax > 0: + import math as _v63_math + _v63_e = _v63_math.floor( + _v63_math.log2(_v63_amax)) + _v63_ulp = 2.0 ** (_v63_e - 7) + _v63_dratio = ( + _v63_d_max / _v63_ulp + if _v63_ulp > 0 else -1.0 + ) + self.logger.info( + "[MTPMainParamCmpAudit-v63] " + "name=%s shape=%s " + "fp32_main_param_sum=%.6e " + "bf16_model_param_sum=%.6e " + "delta_abs_max=%.6e " + "delta_abs_mean=%.6e " + "bf16_ulp=%.6e " + "delta_to_ulp_ratio=%.4f", + name, tuple(_v63_mp_fp32.shape), + float(_v63_mp_fp32.sum().item()), + float(_v63_bf.sum().item()), + _v63_d_max, _v63_d_mean, + _v63_ulp, _v63_dratio, + ) + else: + self.logger.info( + "[MTPMainParamCmpAudit-v63] " + "shape mismatch name=%s " + "main_param=%s bf16=%s", + name, + tuple(_v63_mp_fp32.shape), + tuple(_v63_bf.shape), + ) + else: + self.logger.info( + "[MTPMainParamCmpAudit-v63] " + "name=%s main_param=None " + "(no fp32 master on this rank)", + name, + ) + except Exception as _e_v63_mp: + try: + self.logger.info( + "[MTPMainParamCmpAudit-v63] failure: %r", + _e_v63_mp, + ) + except Exception: + pass mtp_hf_tensors.extend( convert_to_hf( self.tf_config, @@ -8051,6 +8171,126 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: _e_v62_out) except Exception: pass + # [MTPShipKeyOverlap-v63] After all MTP HF tensors are collected + # AND the main bucket converted_named_tensors is finalised, + # cross-check whether MTP HF names overlap with main-bucket + # HF names being shipped in the SAME wave. sglang's EAGLE + # draft model shares some backbone weights (embedding, + # output_layer) with the target model; if MTP-collected tensors + # collide with main-bucket HF names, one would overwrite the + # other in unpredictable order, causing post-ship draft + # regression that matches what spec_v1.log.5 shows. + try: + if (_collect_mtp_for_draft + and mtp_hf_tensors + and dist.get_rank() == 0): + _v63_mtp_names = set(n for n, _ in mtp_hf_tensors) + _v63_main_names = set() + try: + _v63_main_names = set( + n for n, _ in (converted_named_tensors or []) + ) + except Exception: + pass + _v63_overlap = sorted( + _v63_mtp_names & _v63_main_names) + self.logger.info( + "[MTPShipKeyOverlap-v63] version=%s " + "n_mtp=%d n_main=%d n_overlap=%d " + "overlap_keys=%s " + "mtp_only_sample=%s main_only_sample=%s", + str(getattr(meta, 'version', 'NA')), + len(_v63_mtp_names), len(_v63_main_names), + len(_v63_overlap), + str(_v63_overlap[:16]), + str(sorted(_v63_mtp_names - _v63_main_names)[:8]), + str(sorted(_v63_main_names - _v63_mtp_names)[:8]), + ) + if _v63_overlap: + self.logger.warning( + "[MTPShipKeyOverlap-v63] OVERLAP DETECTED " + "version=%s — %d HF names ship in BOTH the " + "main bucket AND the MTP wire. SGLang receives " + "BOTH writes for the same key; last-writer " + "wins and may overwrite the MTP-trained value " + "with the main-model value (or vice versa). " + "Sample: %s", + str(getattr(meta, 'version', 'NA')), + len(_v63_overlap), str(_v63_overlap[:8]), + ) + except Exception as _e_v63_ko: + try: + self.logger.info( + "[MTPShipKeyOverlap-v63] failure: %r", _e_v63_ko, + ) + except Exception: + pass + + # [MTPDraftReadbackV4-v63] Probe alternative sglang readback + # endpoints to capture what the draft model ACTUALLY has + # post-ship. v32's /get_weights_by_name path is blocked for + # MiMo; this v63 probe attempts /update_weights_from_tensor + # echo paths and a generic /get_internal_state fallback so + # that next round we can correlate ship-time hash with + # draft-side hash even when one channel is blocked. + try: + import os as _v63_os_rb + if (_collect_mtp_for_draft + and mtp_hf_tensors + and dist.get_rank() == 0 + and _v63_os_rb.environ.get( + 'AREAL_MTP_DRAFT_READBACK_V4', '1') == '1'): + _v63_rb_engine = getattr( + self, 'rollout_engine', None) + _v63_rb_endpoints = [ + 'get_weights_by_name', + 'get_internal_state', + 'flush_cache', + ] + for _v63_ep in _v63_rb_endpoints: + _v63_fn = getattr( + _v63_rb_engine, _v63_ep, None) + self.logger.info( + "[MTPDraftReadbackV4-v63] version=%s " + "endpoint=%s callable=%s", + str(getattr(meta, 'version', 'NA')), + _v63_ep, str(callable(_v63_fn)), + ) + if callable(_v63_fn): + try: + if _v63_ep == 'get_weights_by_name': + _v63_rb_names = set(n for n, _ in mtp_hf_tensors) + _v63_target_name = next( + iter(_v63_rb_names), None) + if _v63_target_name is not None: + _v63_rb_res = _v63_fn( + _v63_target_name) + else: + _v63_rb_res = None + else: + _v63_rb_res = _v63_fn() + self.logger.info( + "[MTPDraftReadbackV4-v63] " + "endpoint=%s status=OK " + "result_type=%s " + "result_repr_head=%.200s", + _v63_ep, type(_v63_rb_res).__name__, + repr(_v63_rb_res), + ) + except Exception as _e_v63_ep: + self.logger.info( + "[MTPDraftReadbackV4-v63] " + "endpoint=%s status=FAIL err=%r", + _v63_ep, _e_v63_ep, + ) + except Exception as _e_v63_rb: + try: + self.logger.info( + "[MTPDraftReadbackV4-v63] outer failure: %r", + _e_v63_rb, + ) + except Exception: + pass if _collect_mtp_for_draft and mtp_hf_tensors and dist.get_rank() == 0: try: diff --git a/areal/engine/megatron_utils/megatron.py b/areal/engine/megatron_utils/megatron.py index 93c7b9c43b..59471e7911 100644 --- a/areal/engine/megatron_utils/megatron.py +++ b/areal/engine/megatron_utils/megatron.py @@ -987,8 +987,73 @@ def _convert_mimo_mtp_param( # MiMo-specific: swap column halves for eh_proj weight if component == "eh_proj.weight": + # [MTPShipPreSwapAudit-v63] BEFORE the column-half swap, log tensor + # stats so we can compare with slime's _mcore_to_hf_format output. + # Mismatch here would prove ship-time swap differs from load-time swap. + try: + import logging as _v63_log_mod + import hashlib as _v63_hash_mod + _v63_lg = _v63_log_mod.getLogger("AReaL") + _v63_pre_t = param.detach().contiguous() + _v63_pre_bytes = _v63_pre_t.float().cpu().numpy().tobytes() + _v63_pre_h = _v63_hash_mod.sha256(_v63_pre_bytes).hexdigest()[:16] + _v63_pre_first = [ + float(x) for x in + _v63_pre_t.reshape(-1)[:8].float().cpu().tolist() + ] + _v63_pre_last = [ + float(x) for x in + _v63_pre_t.reshape(-1)[-8:].float().cpu().tolist() + ] + _v63_lg.info( + "[MTPShipPreSwapAudit-v63] stage=PRE layer=%s component=%s " + "shape=%s dtype=%s sha256_16=%s first8=%s last8=%s", + str(layer_idx), component, + tuple(param.shape), str(param.dtype), + _v63_pre_h, str(_v63_pre_first), str(_v63_pre_last), + ) + except Exception as _e_v63_pre: + try: + import logging as _v63_log_mod_b + _v63_log_mod_b.getLogger("AReaL").info( + "[MTPShipPreSwapAudit-v63] PRE failure: %r", _e_v63_pre, + ) + except Exception: + pass first_half, second_half = param.chunk(2, dim=1) param = torch.cat([second_half, first_half], dim=1) + # [MTPShipPreSwapAudit-v63] AFTER swap, log post-swap stats. Compare + # pre vs post sha256_16 to verify the swap actually moved bytes. + try: + import logging as _v63_log_post + import hashlib as _v63_hash_post + _v63_lp = _v63_log_post.getLogger("AReaL") + _v63_post_t = param.detach().contiguous() + _v63_post_bytes = _v63_post_t.float().cpu().numpy().tobytes() + _v63_post_h = _v63_hash_post.sha256(_v63_post_bytes).hexdigest()[:16] + _v63_post_first = [ + float(x) for x in + _v63_post_t.reshape(-1)[:8].float().cpu().tolist() + ] + _v63_post_last = [ + float(x) for x in + _v63_post_t.reshape(-1)[-8:].float().cpu().tolist() + ] + _v63_lp.info( + "[MTPShipPreSwapAudit-v63] stage=POST layer=%s component=%s " + "shape=%s sha256_16=%s first8=%s last8=%s", + str(layer_idx), component, + tuple(param.shape), _v63_post_h, + str(_v63_post_first), str(_v63_post_last), + ) + except Exception as _e_v63_post: + try: + import logging as _v63_log_post_b + _v63_log_post_b.getLogger("AReaL").info( + "[MTPShipPreSwapAudit-v63] POST failure: %r", _e_v63_post, + ) + except Exception: + pass # Check direct mappings first if component in direct_mappings: From 04e2c2052059c6724de90d38d400dc54009e96a5 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 17:26:27 +0800 Subject: [PATCH 136/140] feat: log --- areal/engine/megatron_engine.py | 44 +++++++++ areal/engine/megatron_utils/megatron.py | 76 +++++++++++++++ areal/infra/remote_inf_engine.py | 117 ++++++++++++++++++++++++ 3 files changed, 237 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 3f3ea753d0..3e5b96af3d 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -6784,6 +6784,50 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) except Exception: pass + # [MTPShipPostAGAudit-v64] Right BEFORE convert_to_hf, + # log the post-all_gather _mtp_param tensor. This + # is the EXACT mcore-side payload that flows into + # the HF mapping. Compared with PRE/POST swap + # audits and with WireBytes audit, this nails the + # location of any divergence. + try: + import hashlib as _v64_pag_hash + _v64_pag_t = _mtp_param.detach().contiguous() + _v64_pag_b = ( + _v64_pag_t.float().cpu().numpy().tobytes() + ) + _v64_pag_h = _v64_pag_hash.sha256( + _v64_pag_b).hexdigest()[:16] + _v64_pag_f8 = [ + float(x) for x in + _v64_pag_t.reshape(-1)[:8].float() + .cpu().tolist() + ] + try: + _v64_pag_ver = int(self.get_version()) + except Exception: + _v64_pag_ver = -1 + self.logger.info( + "[MTPShipPostAGAudit-v64] version=%d " + "name=%s shape=%s dtype=%s " + "sha256_16=%s first8=%s " + "abs_mean=%.6e abs_max=%.6e l2=%.6e", + _v64_pag_ver, name, + tuple(_v64_pag_t.shape), + str(_v64_pag_t.dtype), + _v64_pag_h, str(_v64_pag_f8), + float(_v64_pag_t.float().abs().mean().item()), + float(_v64_pag_t.float().abs().max().item()), + float(_v64_pag_t.float().norm().item()), + ) + except Exception as _e_v64_pag: + try: + self.logger.info( + "[MTPShipPostAGAudit-v64] failure: %r", + _e_v64_pag, + ) + except Exception: + pass mtp_hf_tensors.extend( convert_to_hf( self.tf_config, diff --git a/areal/engine/megatron_utils/megatron.py b/areal/engine/megatron_utils/megatron.py index 59471e7911..c31ecd7087 100644 --- a/areal/engine/megatron_utils/megatron.py +++ b/areal/engine/megatron_utils/megatron.py @@ -987,6 +987,45 @@ def _convert_mimo_mtp_param( # MiMo-specific: swap column halves for eh_proj weight if component == "eh_proj.weight": + # [MTPShipSwapAudit-v64] PRE — log tensor stats just BEFORE the + # eh_proj column-half swap. Uses module-level logger to ensure + # the record actually reaches stdout (same logger emits + # [MTPLoad] / [MTPLoadHashAudit-v61] verified in + # spec_v1.log.7). This audit fires iff the swap branch is + # actually entered at ship time — answers whether the wire + # divergence comes from this path or another path. + try: + import logging as _v64_pre_log + import hashlib as _v64_pre_hash + _v64_pre_lg = _v64_pre_log.getLogger(__name__) + _v64_pre_t = param.detach().contiguous() + _v64_pre_b = _v64_pre_t.float().cpu().numpy().tobytes() + _v64_pre_h = _v64_pre_hash.sha256(_v64_pre_b).hexdigest()[:16] + _v64_pre_f8 = [ + float(x) for x in + _v64_pre_t.reshape(-1)[:8].float().cpu().tolist() + ] + _v64_pre_l8 = [ + float(x) for x in + _v64_pre_t.reshape(-1)[-8:].float().cpu().tolist() + ] + _v64_pre_lg.info( + "[MTPShipSwapAudit-v64] stage=PRE layer=%s " + "component=%s shape=%s dtype=%s sha256_16=%s " + "first8=%s last8=%s", + str(layer_idx), component, + tuple(param.shape), str(param.dtype), + _v64_pre_h, str(_v64_pre_f8), str(_v64_pre_l8), + ) + except Exception as _e_v64_pre: + try: + import logging as _v64_pre_logb + _v64_pre_logb.getLogger(__name__).warning( + "[MTPShipSwapAudit-v64] PRE failure: %r", + _e_v64_pre, + ) + except Exception: + pass # [MTPShipPreSwapAudit-v63] BEFORE the column-half swap, log tensor # stats so we can compare with slime's _mcore_to_hf_format output. # Mismatch here would prove ship-time swap differs from load-time swap. @@ -1022,6 +1061,43 @@ def _convert_mimo_mtp_param( pass first_half, second_half = param.chunk(2, dim=1) param = torch.cat([second_half, first_half], dim=1) + # [MTPShipSwapAudit-v64] POST — log tensor stats AFTER the + # column-half swap. Comparing PRE vs POST sha256_16 confirms + # bytes actually moved. Comparing POST sha256_16 with the + # next-stage [MTPWireBytesAudit-v64] tells us whether anything + # mutates the tensor between this function and the HTTP send. + try: + import logging as _v64_post_log + import hashlib as _v64_post_hash + _v64_post_lg = _v64_post_log.getLogger(__name__) + _v64_post_t = param.detach().contiguous() + _v64_post_b = _v64_post_t.float().cpu().numpy().tobytes() + _v64_post_h = _v64_post_hash.sha256(_v64_post_b).hexdigest()[:16] + _v64_post_f8 = [ + float(x) for x in + _v64_post_t.reshape(-1)[:8].float().cpu().tolist() + ] + _v64_post_l8 = [ + float(x) for x in + _v64_post_t.reshape(-1)[-8:].float().cpu().tolist() + ] + _v64_post_lg.info( + "[MTPShipSwapAudit-v64] stage=POST layer=%s " + "component=%s shape=%s dtype=%s sha256_16=%s " + "first8=%s last8=%s", + str(layer_idx), component, + tuple(param.shape), str(param.dtype), + _v64_post_h, str(_v64_post_f8), str(_v64_post_l8), + ) + except Exception as _e_v64_post: + try: + import logging as _v64_post_logb + _v64_post_logb.getLogger(__name__).warning( + "[MTPShipSwapAudit-v64] POST failure: %r", + _e_v64_post, + ) + except Exception: + pass # [MTPShipPreSwapAudit-v63] AFTER swap, log post-swap stats. Compare # pre vs post sha256_16 to verify the swap actually moved bytes. try: diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index 22150585bb..54705700b7 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -1063,6 +1063,66 @@ def update_weights_from_tensor_serialized( f"payload_keys={_payload_keys}, n_serialized_tensors={_n_tensors}, " f"addresses={self.addresses}" ) + # [MTPWireBytesAudit-v64] Hash the actual serialized payload + # for the critical MTP HF names just BEFORE the HTTP send to + # sglang. If this hash does not match the [MTPShipSwapAudit-v64] + # POST hash for input_proj.weight, mutation is happening + # between megatron's convert_to_hf and the wire send. If they + # match but [MTPDraftSglangProbe-v64] returns a different + # value, divergence is on the sglang side. + try: + import hashlib as _v64_w_hash + _v64_w_named = serialized_payload.get( + "serialized_named_tensors", [] + ) if isinstance(serialized_payload, dict) else [] + _v64_w_targets = ( + "model.mtp_layers.0.input_proj.weight", + "model.mtp_layers.0.token_layernorm.weight", + "model.mtp_layers.0.hidden_layernorm.weight", + "model.mtp_layers.0.final_layernorm.weight", + ) + _v64_w_count = 0 + for _v64_w_item in _v64_w_named: + _v64_w_name = None + _v64_w_blob = None + if ( + isinstance(_v64_w_item, (list, tuple)) + and len(_v64_w_item) >= 2 + ): + _v64_w_name = _v64_w_item[0] + _v64_w_blob = _v64_w_item[1] + elif isinstance(_v64_w_item, dict): + _v64_w_name = _v64_w_item.get("name") + _v64_w_blob = _v64_w_item.get( + "tensor", _v64_w_item.get("data")) + if _v64_w_name is None or _v64_w_name not in _v64_w_targets: + continue + if isinstance(_v64_w_blob, str): + _v64_w_raw = _v64_w_blob.encode("utf-8") + elif isinstance(_v64_w_blob, (bytes, bytearray)): + _v64_w_raw = bytes(_v64_w_blob) + else: + _v64_w_raw = repr(_v64_w_blob).encode("utf-8") + _v64_w_h = _v64_w_hash.sha256(_v64_w_raw).hexdigest()[:16] + logger.info( + "[MTPWireBytesAudit-v64] hf_name=%s " + "blob_type=%s blob_size=%d sha256_16=%s", + _v64_w_name, type(_v64_w_blob).__name__, + len(_v64_w_raw), _v64_w_h, + ) + _v64_w_count += 1 + logger.info( + "[MTPWireBytesAudit-v64] summary n_hashed=%d " + "n_named_tensors=%d addresses=%s", + _v64_w_count, len(_v64_w_named), self.addresses, + ) + except Exception as _e_v64_w: + try: + logger.warning( + "[MTPWireBytesAudit-v64] failure: %r", _e_v64_w, + ) + except Exception: + pass http_req = HttpRequest( endpoint="/update_weights_from_tensor", payload=serialized_payload, @@ -1077,6 +1137,63 @@ def update_weights_from_tensor_serialized( f"[DiagMTP][Worker] update_weights_from_tensor_serialized " f"COMPLETED in {_time.time() - _t0:.3f}s" ) + # [MTPDraftSglangProbe-v64] After the wire send completes, + # probe the sglang draft via /get_weights_by_name to read + # the bytes the draft model actually holds for our shipped + # MTP names. This is the only way to determine whether + # divergence is wire-side or draft-side (sglang internal + # layout / column-half-swap mismatch with mcore output). + try: + import hashlib as _v64_pr_hash + import requests as _v64_pr_req + _v64_pr_targets = [ + "model.mtp_layers.0.input_proj.weight", + "model.mtp_layers.0.token_layernorm.weight", + ] + for _v64_pr_addr in self.addresses: + for _v64_pr_n in _v64_pr_targets: + _v64_pr_url = ( + f"http://{_v64_pr_addr}/get_weights_by_name" + ) + try: + _v64_pr_resp = _v64_pr_req.post( + _v64_pr_url, + json={ + "name": _v64_pr_n, + "truncate_size": 32, + }, + timeout=10, + ) + _v64_pr_status = int( + _v64_pr_resp.status_code) + _v64_pr_body = _v64_pr_resp.text + _v64_pr_h = _v64_pr_hash.sha256( + _v64_pr_body.encode("utf-8") + ).hexdigest()[:16] + logger.info( + "[MTPDraftSglangProbe-v64] addr=%s " + "name=%s status=%d body_len=%d " + "sha256_16=%s head=%.300s", + _v64_pr_addr, _v64_pr_n, + _v64_pr_status, + len(_v64_pr_body), + _v64_pr_h, _v64_pr_body, + ) + except Exception as _e_v64_pr_inner: + logger.info( + "[MTPDraftSglangProbe-v64] addr=%s " + "name=%s FAIL err=%r", + _v64_pr_addr, _v64_pr_n, + _e_v64_pr_inner, + ) + except Exception as _e_v64_pr: + try: + logger.warning( + "[MTPDraftSglangProbe-v64] outer failure: %r", + _e_v64_pr, + ) + except Exception: + pass except Exception as e: logger.error( f"[DiagMTP][Worker] update_weights_from_tensor_serialized " From 873f6a348ac2e38da97a4496e388af3fbd04e58d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 17:48:44 +0800 Subject: [PATCH 137/140] feat: MTPShipEntryAudit --- areal/engine/megatron_engine.py | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 3e5b96af3d..ac3768d476 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -883,6 +883,42 @@ def prepare_batch( ) def update_weights(self, meta: WeightUpdateMeta): + # [MTPShipEntryAudit-v65] First-line audit: prove the ship + # entry was even reached. log.8 ran ~10min and never hit + # ship; zero v64 records emitted. This audit fires before + # any rollout-connection check so we can distinguish 'ship + # never invoked' from 'ship invoked but failed inside'. + try: + import logging as _v65_log_mod + import time as _v65_time_mod + _v65_lg = _v65_log_mod.getLogger(__name__) + try: + _v65_ver = int(self.get_version()) + except Exception: + _v65_ver = -1 + try: + _v65_meta_type = str(getattr(meta, 'type', '?')) + except Exception: + _v65_meta_type = '?' + try: + _v65_meta_path = str(getattr(meta, 'path', '')) + except Exception: + _v65_meta_path = '' + _v65_lg.info( + "[MTPShipEntryAudit-v65] update_weights ENTER " + "version=%d meta_type=%s meta_path=%s ts=%.3f", + _v65_ver, _v65_meta_type, _v65_meta_path, + _v65_time_mod.time(), + ) + except Exception as _e_v65: + try: + import logging as _v65_log_mod_b + _v65_log_mod_b.getLogger(__name__).warning( + "[MTPShipEntryAudit-v65] entry-audit failure: %r", + _e_v65, + ) + except Exception: + pass self._check_rollout_engine_connected() if meta.type == "xccl": assert self.weight_update_group_initialized From 59af971d4c27d3dc81f69fcd8d9f4aaf612a8b31 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 18:20:59 +0800 Subject: [PATCH 138/140] fix(engine): SigmaDeltaBf16 --- areal/engine/megatron_engine.py | 41 +++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ac3768d476..cdd82f7aab 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -6973,10 +6973,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: # is deterministic and preserves monotonic # sub-ULP trajectories. # - # NOTES - # * slime/verl do not address this. Research of - # https://github.com/THUDM/slime , - # https://github.com/volcengine/verl , SGLang + # NOTES SGLang # v0.5.9 and Megatron-LM core_r0.16.0 confirms # they all ship bf16 round-to-nearest. See # megatron distrib_optimizer.py @@ -6998,6 +6995,42 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: ) except Exception: _sd_on = True + # [MTPSlimeAlign-v66] When slime-align is ON, + # disable v28 SigmaDeltaBf16 entirely. Slime's + # _weight_to_hf_format performs only a shape + # transform + RNE bf16 cast; no residual carry, + # no sub-ULP drift injection. Leaving v28 ON + # under SLIME_ALIGN ships ULP-noisy MTP + # weights, which (per spec_v1.log.9) drives + # AcceptEMA from ~0.32 to ~0.18 within 30s of a + # single ship. Gating v28 here matches the I) + # promise of the SLIME_ALIGN banner: AReaL's + # ship-side mutation surface == slime's. + try: + _v66_slime = ( + _os_v16.environ.get( + 'AREAL_MTP_SLIME_ALIGN', '1' + ) == '1' + ) + except Exception: + _v66_slime = True + if _v66_slime and _sd_on: + _sd_on = False + if not getattr( + self, '_v66_i_logged', False + ): + try: + self.logger.info( + '[MTPSlimeAlign] I) v28 ' + 'SigmaDeltaBf16 DISABLED under ' + 'AREAL_MTP_SLIME_ALIGN=1 (slime ' + 'ships clean RNE bf16; no ' + 'residual carry, no sub-ULP ' + 'dither).' + ) + self._v66_i_logged = True + except Exception: + pass if _sd_on: # [v34] Defensive torch import: v28 SigmaDelta # block references _torch_v16 but the original From 36d72f912c0ed3a1abbcb9955a04614565c119b5 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 9 May 2026 10:12:48 +0800 Subject: [PATCH 139/140] feat(infra): MTPWireBytesFloat --- areal/infra/remote_inf_engine.py | 103 +++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index 54705700b7..84c654178d 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -1104,6 +1104,92 @@ def update_weights_from_tensor_serialized( else: _v64_w_raw = repr(_v64_w_blob).encode("utf-8") _v64_w_h = _v64_w_hash.sha256(_v64_w_raw).hexdigest()[:16] + # [MTPWireBytesFloatAudit-v67] decode the + # base64 fp32 payload and log first8/last8 + + # abs_mean/abs_max/l2. Compare against the + # most recent [MTPShipPostAGAudit-v64] for the + # same hf_name to detect any mutation between + # megatron convert_to_hf and the actual wire + # send to sglang. + try: + import base64 as _v67_b64 + import struct as _v67_struct + _v67_payload = None + if isinstance(_v64_w_blob, str): + try: + _v67_payload = _v67_b64.b64decode( + _v64_w_blob, validate=False) + except Exception: + _v67_payload = None + elif isinstance( + _v64_w_blob, (bytes, bytearray) + ): + _v67_payload = bytes(_v64_w_blob) + if _v67_payload is not None and len( + _v67_payload + ) >= 32: + _v67_n = len(_v67_payload) // 4 + _v67_first_n = min(8, _v67_n) + _v67_last_n = min(8, _v67_n) + _v67_first8 = list( + _v67_struct.unpack( + "<%df" % _v67_first_n, + _v67_payload[:4 * _v67_first_n], + ) + ) + _v67_last8 = list( + _v67_struct.unpack( + "<%df" % _v67_last_n, + _v67_payload[-4 * _v67_last_n:], + ) + ) + _v67_sample = min(_v67_n, 65536) + _v67_floats = list( + _v67_struct.unpack( + "<%df" % _v67_sample, + _v67_payload[:4 * _v67_sample], + ) + ) + _v67_abs_mean = ( + sum(abs(_x) for _x in _v67_floats) + / max(1, len(_v67_floats)) + ) + _v67_abs_max = max( + (abs(_x) for _x in _v67_floats), + default=0.0, + ) + _v67_l2_sq = sum( + _x * _x for _x in _v67_floats + ) + _v67_l2 = _v67_l2_sq ** 0.5 + logger.info( + "[MTPWireBytesFloatAudit-v67] " + "hf_name=%s nbytes=%d nfloats=%d " + "first8=%s last8=%s " + "abs_mean=%.6e abs_max=%.6e " + "l2_sample=%.6e sample_n=%d", + _v64_w_name, len(_v67_payload), + _v67_n, _v67_first8, _v67_last8, + _v67_abs_mean, _v67_abs_max, + _v67_l2, _v67_sample, + ) + else: + logger.info( + "[MTPWireBytesFloatAudit-v67] " + "hf_name=%s SKIPPED reason=%s", + _v64_w_name, + "no_payload" if _v67_payload is None + else "payload_too_short", + ) + except Exception as _e_v67_a: + try: + logger.warning( + "[MTPWireBytesFloatAudit-v67] " + "hf_name=%s failure: %r", + _v64_w_name, _e_v67_a, + ) + except Exception: + pass logger.info( "[MTPWireBytesAudit-v64] hf_name=%s " "blob_type=%s blob_size=%d sha256_16=%s", @@ -1146,9 +1232,26 @@ def update_weights_from_tensor_serialized( try: import hashlib as _v64_pr_hash import requests as _v64_pr_req + # [MTPDraftSglangProbeInternal-v67] sglang + # MiMoMTP.load_weights re-keys HF names + # "model.mtp_layers.0.{input_proj, + # token_layernorm,hidden_layernorm, + # final_layernorm}.weight" + # to internal + # "model.{input_proj,token_layernorm, + # hidden_layernorm,final_layernorm}.weight" + # (and "model.mtp_block.*" for transformer). + # /get_weights_by_name expects the INTERNAL + # name; querying the HF name returns 400. + # We probe both so at least one set returns + # data, allowing wire <-> draft comparison. _v64_pr_targets = [ "model.mtp_layers.0.input_proj.weight", "model.mtp_layers.0.token_layernorm.weight", + "model.input_proj.weight", + "model.token_layernorm.weight", + "model.hidden_layernorm.weight", + "model.final_layernorm.weight", ] for _v64_pr_addr in self.addresses: for _v64_pr_n in _v64_pr_targets: From 48325d3c0e04f8507b689df441f9b7e3d21ff433 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 20:36:39 +0800 Subject: [PATCH 140/140] feat(remote_inf_engine): fix --- areal/infra/remote_inf_engine.py | 302 +++++++++++++++++++++++++++++++ 1 file changed, 302 insertions(+) diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index 84c654178d..62dca95492 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -1209,6 +1209,244 @@ def update_weights_from_tensor_serialized( ) except Exception: pass + # [MTPWireBundleAudit-v68] In sglang's update_weights_from_tensor + # protocol the megatron side calls + # per_rank = [(name, LocalSerializedTensor(values=[inner_b])) + # for name, inner_b in pairs] + # outer = MultiprocessingSerializer.serialize(per_rank) + # payload = {"serialized_named_tensors": [b64(outer)] * tp_size} + # Therefore on the wire serialized_named_tensors[0] is NOT a + # (name, blob) tuple but a single base64 string of the outer + # bundle, which the v64/v67 audits could not iterate over and + # produced n_hashed=0 in log.11. v68 base64-decodes the outer + # bundle, deserializes it via sglang's MultiprocessingSerializer, + # extracts the inner LocalSerializedTensor for each MTP target, + # deserializes that to a torch.Tensor, and logs first8/last8 + + # abs_mean/abs_max/l2 + sha256_16 of the actual fp32 wire bytes. + try: + import base64 as _v68_b64 + import hashlib as _v68_hash + import struct as _v68_struct + _v68_named_outer = serialized_payload.get( + "serialized_named_tensors", [] + ) if isinstance(serialized_payload, dict) else [] + _v68_targets = ( + "model.mtp_layers.0.token_layernorm.weight", + "model.mtp_layers.0.hidden_layernorm.weight", + "model.mtp_layers.0.input_proj.weight", + "model.mtp_layers.0.final_layernorm.weight", + "model.mtp_layers.0.input_layernorm.weight", + "model.mtp_layers.0.post_attention_layernorm.weight", + "model.mtp_layers.0.self_attn.o_proj.weight", + "model.mtp_layers.0.self_attn.q_proj.weight", + "model.mtp_layers.0.self_attn.k_proj.weight", + "model.mtp_layers.0.self_attn.v_proj.weight", + "model.mtp_layers.0.mlp.gate_proj.weight", + "model.mtp_layers.0.mlp.up_proj.weight", + "model.mtp_layers.0.mlp.down_proj.weight", + ) + _v68_handled = 0 + for _v68_outer_idx, _v68_outer_item in enumerate( + _v68_named_outer + ): + # Phase 1: decode outer base64 bundle + _v68_outer_bytes = None + if isinstance(_v68_outer_item, str): + try: + _v68_outer_bytes = _v68_b64.b64decode( + _v68_outer_item, validate=False) + except Exception as _e_v68_b: + logger.warning( + "[MTPWireBundleAudit-v68] outer_idx=%d " + "b64 decode failed: %r", + _v68_outer_idx, _e_v68_b, + ) + continue + elif isinstance(_v68_outer_item, (bytes, bytearray)): + _v68_outer_bytes = bytes(_v68_outer_item) + else: + logger.info( + "[MTPWireBundleAudit-v68] outer_idx=%d " + "unexpected_type=%s — skipping", + _v68_outer_idx, + type(_v68_outer_item).__name__, + ) + continue + if not _v68_outer_bytes: + continue + logger.info( + "[MTPWireBundleAudit-v68] outer_idx=%d " + "b64_len=%d outer_bytes=%d sha256_16=%s", + _v68_outer_idx, + len(_v68_outer_item) if isinstance( + _v68_outer_item, str) else -1, + len(_v68_outer_bytes), + _v68_hash.sha256( + _v68_outer_bytes + ).hexdigest()[:16], + ) + # Phase 2: import sglang serializer + try: + from sglang.srt.utils import ( + MultiprocessingSerializer as _v68_MPS, + ) + except Exception as _e_v68_imp_mps: + logger.warning( + "[MTPWireBundleAudit-v68] cannot import " + "MultiprocessingSerializer: %r", + _e_v68_imp_mps, + ) + continue + try: + _v68_per_rank = _v68_MPS.deserialize( + _v68_outer_bytes) + except Exception as _e_v68_dso: + logger.warning( + "[MTPWireBundleAudit-v68] outer " + "deserialize failed: %r", _e_v68_dso, + ) + continue + if not isinstance( + _v68_per_rank, (list, tuple) + ): + logger.info( + "[MTPWireBundleAudit-v68] per_rank type=%s " + "len=N/A — unexpected", + type(_v68_per_rank).__name__, + ) + continue + _v68_all_names = [] + for _v68_pair in _v68_per_rank: + if ( + isinstance(_v68_pair, (list, tuple)) + and len(_v68_pair) >= 1 + ): + _v68_all_names.append(_v68_pair[0]) + logger.info( + "[MTPWireBundleAudit-v68] per_rank n_pairs=%d " + "all_names=%s", + len(_v68_per_rank), _v68_all_names, + ) + # Phase 3: enumerate inner pairs, audit targets + for _v68_pair in _v68_per_rank: + if not ( + isinstance(_v68_pair, (list, tuple)) + and len(_v68_pair) >= 2 + ): + continue + _v68_name = _v68_pair[0] + _v68_lst = _v68_pair[1] + if _v68_name not in _v68_targets: + continue + # LocalSerializedTensor.values[0] holds inner bytes + _v68_inner = None + try: + _v68_inner = getattr( + _v68_lst, "values", None) + if ( + isinstance(_v68_inner, (list, tuple)) + and len(_v68_inner) >= 1 + ): + _v68_inner_b = _v68_inner[0] + else: + _v68_inner_b = _v68_inner + except Exception: + _v68_inner_b = None + if _v68_inner_b is None: + logger.info( + "[MTPWireBundleAudit-v68] hf_name=%s " + "NO_INNER lst_type=%s", + _v68_name, type(_v68_lst).__name__, + ) + continue + # Inner is itself a serialized torch.Tensor blob + try: + _v68_tensor = _v68_MPS.deserialize( + _v68_inner_b) + except Exception as _e_v68_dsi: + logger.warning( + "[MTPWireBundleAudit-v68] hf_name=%s " + "inner deserialize failed: %r", + _v68_name, _e_v68_dsi, + ) + continue + try: + import torch as _v68_torch + if isinstance(_v68_tensor, _v68_torch.Tensor): + _v68_t_cpu = _v68_tensor.detach().to( + "cpu", dtype=_v68_torch.float32 + ).contiguous() + _v68_flat = _v68_t_cpu.flatten() + _v68_n_el = int(_v68_flat.numel()) + _v68_first_n = min(8, _v68_n_el) + _v68_last_n = min(8, _v68_n_el) + _v68_first8 = ( + _v68_flat[:_v68_first_n].tolist() + ) + _v68_last8 = ( + _v68_flat[-_v68_last_n:].tolist() + if _v68_last_n > 0 else [] + ) + _v68_abs_mean = float( + _v68_flat.abs().mean().item() + ) if _v68_n_el > 0 else 0.0 + _v68_abs_max = float( + _v68_flat.abs().max().item() + ) if _v68_n_el > 0 else 0.0 + _v68_l2 = float( + _v68_flat.norm(p=2).item() + ) if _v68_n_el > 0 else 0.0 + _v68_raw = ( + _v68_flat.numpy().tobytes() + if _v68_n_el > 0 else b"" + ) + _v68_sha = _v68_hash.sha256( + _v68_raw + ).hexdigest()[:16] + logger.info( + "[MTPWireBundleAudit-v68] " + "hf_name=%s shape=%s dtype=%s " + "numel=%d sha256_16=%s " + "first8=%s last8=%s " + "abs_mean=%.6e abs_max=%.6e " + "l2=%.6e", + _v68_name, + tuple(_v68_tensor.shape), + str(_v68_tensor.dtype), + _v68_n_el, _v68_sha, + _v68_first8, _v68_last8, + _v68_abs_mean, _v68_abs_max, + _v68_l2, + ) + _v68_handled += 1 + else: + logger.info( + "[MTPWireBundleAudit-v68] " + "hf_name=%s inner_type=%s — not a " + "torch.Tensor", + _v68_name, + type(_v68_tensor).__name__, + ) + except Exception as _e_v68_tens: + logger.warning( + "[MTPWireBundleAudit-v68] hf_name=%s " + "tensor handling failed: %r", + _v68_name, _e_v68_tens, + ) + logger.info( + "[MTPWireBundleAudit-v68] summary handled=%d " + "outer_n=%d addresses=%s", + _v68_handled, len(_v68_named_outer), + self.addresses, + ) + except Exception as _e_v68: + try: + logger.warning( + "[MTPWireBundleAudit-v68] outer failure: %r", + _e_v68, + ) + except Exception: + pass http_req = HttpRequest( endpoint="/update_weights_from_tensor", payload=serialized_payload, @@ -1289,6 +1527,70 @@ def update_weights_from_tensor_serialized( _v64_pr_addr, _v64_pr_n, _e_v64_pr_inner, ) + # [MTPDraftSglangProbeExt-v68] previous probe + # always returned 400 for both HF and the + # mtp_layers.0-stripped names. v68 widens the + # probe to also try sglang internal MiMoMTP + # rekeys for the transformer block: + # model.mtp_block.input_layernorm.weight + # model.mtp_block.post_attention_layernorm.weight + # model.mtp_block.self_attn.{q,k,v,o}_proj.weight + # model.mtp_block.mlp.{gate,up,down}_proj.weight + # plus the truly bare names sglang may use + # ("input_proj.weight", "token_layernorm.weight", + # ...). At least one of these should succeed + # if sglang holds the MTP weights at all. + _v68_pr_extra = [ + "model.mtp_block.input_layernorm.weight", + "model.mtp_block.post_attention_layernorm.weight", + "model.mtp_block.self_attn.q_proj.weight", + "model.mtp_block.self_attn.k_proj.weight", + "model.mtp_block.self_attn.v_proj.weight", + "model.mtp_block.self_attn.o_proj.weight", + "model.mtp_block.mlp.gate_proj.weight", + "model.mtp_block.mlp.up_proj.weight", + "model.mtp_block.mlp.down_proj.weight", + "input_proj.weight", + "token_layernorm.weight", + "hidden_layernorm.weight", + "final_layernorm.weight", + ] + for _v68_pr_addr in self.addresses: + for _v68_pr_n in _v68_pr_extra: + _v68_pr_url = ( + f"http://{_v68_pr_addr}/get_weights_by_name" + ) + try: + _v68_pr_resp = _v64_pr_req.post( + _v68_pr_url, + json={ + "name": _v68_pr_n, + "truncate_size": 32, + }, + timeout=10, + ) + _v68_pr_status = int( + _v68_pr_resp.status_code) + _v68_pr_body = _v68_pr_resp.text + _v68_pr_h = _v64_pr_hash.sha256( + _v68_pr_body.encode("utf-8") + ).hexdigest()[:16] + logger.info( + "[MTPDraftSglangProbeExt-v68] addr=%s " + "name=%s status=%d body_len=%d " + "sha256_16=%s head=%.300s", + _v68_pr_addr, _v68_pr_n, + _v68_pr_status, + len(_v68_pr_body), + _v68_pr_h, _v68_pr_body, + ) + except Exception as _e_v68_pr_inner: + logger.info( + "[MTPDraftSglangProbeExt-v68] addr=%s " + "name=%s FAIL err=%r", + _v68_pr_addr, _v68_pr_n, + _e_v68_pr_inner, + ) except Exception as _e_v64_pr: try: logger.warning(