-
Notifications
You must be signed in to change notification settings - Fork 443
fix: layerwise mse fix for fuse-experts MoE #1504
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d63f7cf
d15e5f4
eccb12a
a79a157
bcc083b
487bb23
6f0feae
67a2b94
f690277
19f705a
4b66dc1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -227,12 +227,21 @@ def _patched_forward(self, *args, **kwargs): | |||||||
| f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." | ||||||||
| ) | ||||||||
| real_args, real_kwargs = info.cached_inputs.popleft() | ||||||||
| # Move CPU-resident captured inputs back to the layer's device for replay. | ||||||||
| device = get_module_device(self) | ||||||||
| real_args = _move_to_device(real_args, device) | ||||||||
| real_kwargs = _move_to_device(real_kwargs, device) | ||||||||
| output = self._original_forward(*real_args, **real_kwargs) | ||||||||
| info.output_meta = LayerActivationCollector._extract_output_meta(output) | ||||||||
| return output | ||||||||
|
|
||||||||
| if info.mode == "capture": | ||||||||
| info.collected_inputs.append((args, kwargs)) | ||||||||
| # Offload to CPU so the per-layer compute device doesn't OOM | ||||||||
| # while accumulating thousands of batches; "run" moves back. | ||||||||
| cpu = torch.device("cpu") | ||||||||
| info.collected_inputs.append( | ||||||||
| (_move_to_device(args, cpu), _move_to_device(kwargs, cpu)) | ||||||||
| ) | ||||||||
| raise _EarlyStopForwardError() | ||||||||
|
Comment on lines
+230
to
245
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [SUGGESTION Performance] The capture and run modes now do a synchronous host↔device transfer per batch, per layer. With N batches × L layers, that's
|
||||||||
|
|
||||||||
| return self._original_forward(*args, **kwargs) | ||||||||
|
|
@@ -433,6 +442,8 @@ def cache_outputs_for_next_layer_calib( | |||||||
|
|
||||||||
| next_layer = self._decoder_layers[next_idx] | ||||||||
| with persistent_materialization(layer): | ||||||||
| # Free cached-but-unused GPU memory left over from the previous layer's calibration. | ||||||||
| torch.cuda.empty_cache() | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
you could trigger a manual garbage collection to delete any unused tensors
Suggested change
|
||||||||
| return self.get_input_activations(next_layer, forward_loop) | ||||||||
|
|
||||||||
|
|
||||||||
|
|
@@ -512,10 +523,9 @@ def _save_layer( | |||||||
|
|
||||||||
|
|
||||||||
| def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None: | ||||||||
| """Detect where to resume from an existing checkpoint directory. | ||||||||
| """Return ``(start_layer, manifest)`` for an existing checkpoint, else ``None``. | ||||||||
|
|
||||||||
| Returns ``(start_layer, manifest)`` if there is work to resume, | ||||||||
| or ``None`` if the directory is empty, corrupt, or calibration was already complete. | ||||||||
| ``start_layer == num_layers`` signals a fully-completed run. | ||||||||
| """ | ||||||||
| manifest = _read_manifest(checkpoint_dir) | ||||||||
| if manifest is None: | ||||||||
|
|
@@ -524,8 +534,6 @@ def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None: | |||||||
| total = manifest.get("num_layers") | ||||||||
| if last is None or total is None: | ||||||||
| return None | ||||||||
| if last + 1 >= total: | ||||||||
| return None | ||||||||
| return (last + 1, manifest) | ||||||||
|
Comment on lines
525
to
537
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [SUGGESTION] Now that A two-line bound check here would fail fast with a clear "checkpoint corrupt" message: if not isinstance(last, int) or not isinstance(total, int):
return None
if last < -1 or last >= total:
return None
return (last + 1, manifest)(CodeRabbit raised the same point — flagging here because with this PR's "completed" semantics it becomes more relevant: prior to this PR, the |
||||||||
|
|
||||||||
|
|
||||||||
|
|
@@ -568,7 +576,9 @@ def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _Checkpoint | |||||||
| f"but model has {num_layers}. Use a fresh checkpoint directory." | ||||||||
| ) | ||||||||
| start = info[0] if info else 0 | ||||||||
| if start > 0: | ||||||||
| if start >= num_layers: | ||||||||
| print_rank_0(f"Checkpoint: all {num_layers} layers already calibrated") | ||||||||
| elif start > 0: | ||||||||
| print_rank_0( | ||||||||
| f"Checkpoint: resuming layerwise calibration from layer {start}/{num_layers}" | ||||||||
| ) | ||||||||
|
|
@@ -601,8 +611,7 @@ def setup_resume(self, layers: nn.ModuleList) -> list | None: | |||||||
| raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}") | ||||||||
| # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied | ||||||||
| next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False) | ||||||||
| resume_device = get_module_device(layers[self.start_layer]) | ||||||||
| next_inputs = _move_to_device(next_inputs, resume_device) | ||||||||
| # Keep on CPU — _patched_forward's run mode moves each entry to device on pop. | ||||||||
| return next_inputs | ||||||||
|
|
||||||||
| def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: | ||||||||
|
|
@@ -620,23 +629,23 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: | |||||||
| layer = layers[i] | ||||||||
| d = _layer_dir(self.checkpoint_dir, i) | ||||||||
|
|
||||||||
| # Resolve layer_device and load inside the context so params are | ||||||||
| # materialized — otherwise get_module_device can return meta. | ||||||||
| # Load inside the context so params are materialized — otherwise | ||||||||
| # get_module_device can return meta. | ||||||||
| with enable_weight_access_and_writeback(layer, model, name_to_module): | ||||||||
| layer_device = get_module_device(layer) | ||||||||
| # weights_only=False is safe: files are internally generated by _save_layer | ||||||||
| # Load to CPU to avoid serialized-view storage_offset hazards on later clone/deepcopy. | ||||||||
| # weights_only=False is safe: files are internally generated by _save_layer. | ||||||||
| qstate = torch.load( | ||||||||
| os.path.join(d, "quantizer_state.pt"), | ||||||||
| map_location=layer_device, | ||||||||
| map_location="cpu", | ||||||||
| weights_only=False, | ||||||||
| ) | ||||||||
| weights = torch.load( | ||||||||
| os.path.join(d, "weights.pt"), | ||||||||
| map_location=layer_device, | ||||||||
| map_location="cpu", | ||||||||
| weights_only=False, | ||||||||
| ) | ||||||||
| restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate}) | ||||||||
| layer.load_state_dict(weights, strict=False, assign=True) | ||||||||
| layer.load_state_dict(weights, strict=False, assign=False) | ||||||||
|
realAsma marked this conversation as resolved.
|
||||||||
|
|
||||||||
| print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers") | ||||||||
|
|
||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -719,3 +719,56 @@ def test_mtq_quantize_layerwise_raises_for_unsupported_algorithm(): | |
| config, | ||
| forward_loop=lambda m: m(torch.randint(0, 32, (2, 8))), | ||
| ) | ||
|
|
||
|
|
||
| # Checkpoint resume + capture-time CPU offload | ||
|
|
||
|
|
||
| def test_collected_inputs_are_cpu_at_capture(monkeypatch): | ||
| """Captured inputs must be CPU-resident — the OOM-prevention invariant.""" | ||
| _register_test_discoverer(monkeypatch) | ||
| model = _SimpleTwoLayerModel(dim=8) | ||
| collector = LayerActivationCollector(model) | ||
|
|
||
| def forward_loop(m): | ||
| m(torch.randn(2, 8)) | ||
|
|
||
| collector._patch_all_layers() | ||
| try: | ||
| inputs = collector.get_input_activations(model.layers[0], forward_loop) | ||
| finally: | ||
| collector._unpatch_all_layers() | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This asserts the capture-side invariant but doesn't exercise the calibration replay. The OOM fix's correctness for the calib_func path (CPU-stashed inputs eventually reaching a GPU layer's forward) isn't covered — both layerwise e2e tests run on CPU-only models so a CPU-arg / GPU-weight mismatch would never surface. A small GPU smoke test (or even a CPU test that asserts inputs are device-correct at the moment |
||
| args, _ = inputs[0] | ||
| assert args[0].device.type == "cpu", "captured tensor must be CPU-resident" | ||
|
|
||
|
|
||
| def test_layerwise_calibrate_early_returns_on_completed_checkpoint(monkeypatch, tmp_path): | ||
| """Fully-completed checkpoint must short-circuit calibration: no forward_loop calls. | ||
|
|
||
| Indirectly covers ``detect_resume_point`` returning ``(num_layers, manifest)`` | ||
| for a completed run — if it returned ``None``, the loop would re-run and | ||
| forward_loop would be invoked. | ||
| """ | ||
| _register_test_discoverer(monkeypatch) | ||
| torch.manual_seed(0) | ||
| ckpt_dir = str(tmp_path / "ckpt") | ||
| config = _int8_layerwise_config( | ||
| {"method": "max", "layerwise": True, "layerwise_checkpoint_dir": ckpt_dir} | ||
| ) | ||
| calib_data = [torch.randint(0, 32, (2, 8))] | ||
|
|
||
| # First run writes a complete checkpoint. | ||
| model = _SimpleTransformerModel(n_layers=2, dim=16) | ||
| mtq.quantize(model, config, forward_loop=lambda m: [m(b) for b in calib_data]) | ||
|
|
||
| # Second run against the same dir must skip the calibration forward loop. | ||
| call_count = {"n": 0} | ||
|
|
||
| def counting_forward(m): | ||
| call_count["n"] += 1 | ||
| m(calib_data[0]) | ||
|
|
||
| fresh = _SimpleTransformerModel(n_layers=2, dim=16) | ||
| mtq.quantize(fresh, config, forward_loop=counting_forward) | ||
| assert call_count["n"] == 0 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Captured inputs are now CPU-resident, but
_patched_forwardonly moves them back to the layer's device in the"run"branch. The actual calibration forward inlayerwise_calibrateis_layer_forward_loop, which callsm(*args, **kwargs_input)while the layer is in"original"mode — no device move happens there. With HF accelerate hooks the per-modulepre_forwardsilently re-bounces CPU args onto the right GPU, which is presumably why the GLM5.1 e2e run was clean. On a non-accelerate single-GPU model (params on CUDA, args on CPU)_original_forwardwill hit a device mismatch.Consider either (a) moving tensors back to
get_module_device(layer)inside_layer_forward_loopmirroring therun-mode code on line ~232, or (b) adding an explicit comment + assertion that this path requires accelerate-style hooks, so the assumption is at least documented.