|
29 | 29 | class _TinyModel(nn.Module): |
30 | 30 | def __init__(self, events, *, include_draft=False): |
31 | 31 | super().__init__() |
| 32 | + self._weights_transformed = False |
32 | 33 | self._events = events |
33 | 34 | if include_draft: |
34 | 35 | self.draft_model = nn.Module() |
@@ -131,6 +132,19 @@ def _install_gms_backend(monkeypatch, backend): |
131 | 132 | monkeypatch.setattr(memory_mod, "GMSBackend", MagicMock(return_value=backend)) |
132 | 133 |
|
133 | 134 |
|
| 135 | +class _PostTransformMxLoader: |
| 136 | + checkpoint_format = "MX" |
| 137 | + |
| 138 | + def __init__(self) -> None: |
| 139 | + self.load_weights = MagicMock(return_value={}) |
| 140 | + self.is_weights_preloaded = MagicMock(return_value=True) |
| 141 | + self.post_load_apply = MagicMock() |
| 142 | + self.post_load_publish = MagicMock() |
| 143 | + |
| 144 | + def is_post_transform_weights_preloaded(self) -> bool: |
| 145 | + return True |
| 146 | + |
| 147 | + |
134 | 148 | def _spec_config_needing_draft_weights(): |
135 | 149 | return SimpleNamespace( |
136 | 150 | spec_dec_mode=SimpleNamespace(need_load_draft_weights=lambda: True), |
@@ -347,6 +361,58 @@ def test_gms_rw_loader_preload_skips_mapping_pipeline(monkeypatch): |
347 | 361 | backend.finalize_write.assert_called_once_with(model) |
348 | 362 |
|
349 | 363 |
|
def test_gms_rw_mx_post_transform_preload_uses_staged_path(monkeypatch):
    """GMS writers that receive post-transform MX bytes must not transform again.

    Runs ``loader.load`` against an RW GMS backend and an MX checkpoint loader
    that reports post-transform weights as preloaded, then checks the exact
    order of recorded events, the flag passed to ``load_weights``, and the
    arguments of the publish and finalize calls.
    """
    events = []
    loader = _make_loader(monkeypatch, events=events)
    # NOTE(review): presumably this opts ``_TinyModel`` into the staged
    # receiver path for the current protocol version -- confirm against
    # ``ModelLoader``.
    monkeypatch.setattr(
        ModelLoader,
        "_MX_STAGED_RECEIVER_ALLOWLIST",
        frozenset({(_TinyModel, ModelLoader._MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION)}),
    )
    backend = _build_gms_backend(is_rw=True, events=events)
    # Record backend calls in the shared event log so ordering can be checked.
    backend.move_untracked_params.side_effect = lambda _model: events.append(
        "move_untracked_params"
    )
    backend.finalize_write.side_effect = lambda _model: events.append("finalize_write")
    _install_gms_backend(monkeypatch, backend)

    checkpoint_loader = _PostTransformMxLoader()
    checkpoint_loader.post_load_apply.side_effect = lambda *_a, **_kw: events.append(
        "post_load_apply"
    )
    checkpoint_loader.post_load_publish.side_effect = lambda *_a, **_kw: events.append(
        "post_load_publish"
    )

    model, _ = loader.load("/ckpt", checkpoint_loader)

    assert events == [
        "pool_enter",
        "_apply",
        "to",
        "post_load_apply",
        "setup_aliases",
        "cache_derived_state",
        "move_untracked_params",
        "post_load_publish",
        "pool_exit",
        "finalize_write",
    ]
    assert "post_load_weights" not in events
    assert model._weights_transformed is True
    # ``call_args`` is None when the mock was never called; assert the call
    # first so a missing call fails as an assertion, not an unpack TypeError.
    checkpoint_loader.load_weights.assert_called()
    _args, kwargs = checkpoint_loader.load_weights.call_args
    # ``.get`` turns a missing keyword into an assertion failure instead of a
    # KeyError.
    assert kwargs.get("allow_post_transform_weights") is True
    loader._call_load_weights.assert_not_called()
    checkpoint_loader.post_load_publish.assert_called_once_with(
        model,
        checkpoint_dir="/ckpt",
        weights_preloaded=True,
        source_identity=loader._source_identity,
    )
    backend.finalize_write.assert_called_once_with(model)
| 414 | + |
| 415 | + |
350 | 416 | def test_gms_rw_no_load_and_no_preload_raises(monkeypatch): |
351 | 417 | """RW + empty ``weights`` + ``is_weights_preloaded()=False`` is a bug. |
352 | 418 |
|
|
0 commit comments