@@ -44,11 +44,13 @@ def test_no_args_constructs(self):
4444 loader = MXCheckpointLoader ()
4545 assert loader .mx_server_url is None
4646 assert loader .is_weights_preloaded () is False
47+ assert loader .is_post_transform_weights_preloaded () is False
4748
4849 def test_mx_server_url_stored (self ):
4950 loader = MXCheckpointLoader (mx_server_url = "http://mx:8001" )
5051 assert loader .mx_server_url == "http://mx:8001"
5152 assert loader .is_weights_preloaded () is False
53+ assert loader .is_post_transform_weights_preloaded () is False
5254
5355 def test_query_timeout_stored (self ):
5456 loader = MXCheckpointLoader (mx_server_url = "http://mx:8001" , query_timeout_s = 900 )
@@ -75,6 +77,17 @@ def test_checkpoint_format_backing_attr(self):
7577 def test_is_weights_preloaded_initial (self ):
7678 loader = MXCheckpointLoader ()
7779 assert loader .is_weights_preloaded () is False
80+ assert loader .is_post_transform_weights_preloaded () is False
81+
82+ def test_post_transform_signal_requires_p2p_and_identity_match (self ):
83+ loader = MXCheckpointLoader ()
84+ loader ._p2p_succeeded = True
85+ loader ._post_transform_weights_preloaded = True
86+ loader ._source_identity_compatible_for_last_load = False
87+ assert loader .is_post_transform_weights_preloaded () is False
88+
89+ loader ._source_identity_compatible_for_last_load = True
90+ assert loader .is_post_transform_weights_preloaded () is True
7891
7992
8093# ---------------------------------------------------------------------------
@@ -145,9 +158,11 @@ def _modelexpress_unavailable(stack):
145158
146159 @staticmethod
147160 def _upstream_raises (stack ):
161+ loader = MXCheckpointLoader (mx_server_url = "http://mx:8001" )
162+ loader ._source_identity_compatible = MagicMock (return_value = True )
148163 fake_mx = _build_fake_modelexpress (load_weights_side_effect = RuntimeError ("boom" ))
149164 stack .enter_context (_install_fake_modelexpress (fake_mx ))
150- return (MXCheckpointLoader ( mx_server_url = "http://mx:8001" ) , {"model" : MagicMock ()})
165+ return (loader , {"model" : MagicMock ()})
151166
152167 @pytest .mark .parametrize (
153168 "trigger_id, setup" ,
@@ -196,6 +211,7 @@ def test_p2p_full_success_returns_empty_dict(self):
196211 # ``is_weights_preloaded()`` signal as "skip the standard
197212 # weight-mapping pipeline".
198213 loader = MXCheckpointLoader (mx_server_url = "http://mx:8001" )
214+ loader ._source_identity_compatible = MagicMock (return_value = True )
199215 fake_mx = _build_fake_modelexpress (load_weights_return = {})
200216 mapping = MagicMock (name = "mapping" )
201217 model = MagicMock (name = "model" )
@@ -205,6 +221,7 @@ def test_p2p_full_success_returns_empty_dict(self):
205221
206222 assert result == {}
207223 assert loader .is_weights_preloaded () is True
224+ assert loader .is_post_transform_weights_preloaded () is False
208225
209226 # Verify the integration contract with the upstream library:
210227 # 1. Constructed MxLiveWeightLoader with our mx_server_url.
@@ -222,6 +239,7 @@ def test_mixed_success_returns_fallback_weights(self):
222239 # tensors), keep the P2P transfer and let ModelLoader merge these
223240 # tensors through the standard disk pipeline.
224241 loader = MXCheckpointLoader (mx_server_url = "http://mx:8001" )
242+ loader ._source_identity_compatible = MagicMock (return_value = True )
225243 fallback = {"some.weight" : MagicMock ()}
226244 fake_mx = _build_fake_modelexpress (load_weights_return = fallback )
227245
@@ -232,9 +250,36 @@ def test_mixed_success_returns_fallback_weights(self):
232250 result = loader .load_weights ("/nonexistent" , mapping = MagicMock (), model = MagicMock ())
233251
234252 assert loader .is_weights_preloaded () is True
253+ assert loader .is_post_transform_weights_preloaded () is False
235254 assert result is fallback
236255 mock_super_load .assert_not_called ()
237256
257+ def test_post_transform_mixed_success_falls_back_to_full_disk_load (self ):
258+ # Wave 5 will let MX advertise post-transform sources. If such a
259+ # source only partially succeeds, merging raw fallback tensors would
260+ # force ModelLoader onto the full post-load path and double-transform
261+ # the P2P subset. Lock the safer behavior now: abandon the partial
262+ # post-transform transfer and return a full disk load instead.
263+ loader = MXCheckpointLoader (mx_server_url = "http://mx:8001" )
264+ loader ._source_identity_compatible = MagicMock (return_value = True )
265+ loader ._source_metadata_is_post_transform = MagicMock (return_value = True )
266+ fallback = {"some.weight" : MagicMock (numel = lambda : 1 , element_size = lambda : 4 )}
267+ disk_weights = {"disk.weight" : MagicMock ()}
268+ fake_mx = _build_fake_modelexpress (load_weights_return = fallback )
269+
270+ with (
271+ _install_fake_modelexpress (fake_mx ),
272+ patch .object (
273+ HfCheckpointLoader , "load_weights" , return_value = disk_weights
274+ ) as mock_super_load ,
275+ ):
276+ result = loader .load_weights ("/nonexistent" , mapping = MagicMock (), model = MagicMock ())
277+
278+ assert result is disk_weights
279+ assert loader .is_weights_preloaded () is False
280+ assert loader .is_post_transform_weights_preloaded () is False
281+ mock_super_load .assert_called_once ()
282+
238283
239284# ---------------------------------------------------------------------------
240285# publish_as_source — env-var dance and graceful no-op
@@ -428,6 +473,7 @@ def _assert_timeout(*args, **kwargs):
428473 return {}
429474
430475 loader = MXCheckpointLoader (mx_server_url = "http://mx:8001" )
476+ loader ._source_identity_compatible = MagicMock (return_value = True )
431477 fake_mx = _build_fake_modelexpress (load_weights_side_effect = _assert_timeout )
432478 with _install_fake_modelexpress (fake_mx ):
433479 loader .load_weights ("/nonexistent" , mapping = MagicMock (), model = MagicMock ())
@@ -439,6 +485,7 @@ def _assert_no_timeout(*args, **kwargs):
439485 return {}
440486
441487 loader = MXCheckpointLoader (mx_server_url = "http://mx:8001" )
488+ loader ._source_identity_compatible = MagicMock (return_value = True )
442489 fake_mx = _build_fake_modelexpress (
443490 load_weights_side_effect = _assert_no_timeout ,
444491 source_instances = [MagicMock ()],
@@ -457,6 +504,7 @@ def _assert_env_timeout(*args, **kwargs):
457504 return {}
458505
459506 loader = MXCheckpointLoader (mx_server_url = "http://mx:8001" )
507+ loader ._source_identity_compatible = MagicMock (return_value = True )
460508 fake_mx = _build_fake_modelexpress (load_weights_side_effect = _assert_env_timeout )
461509 with _install_fake_modelexpress (fake_mx ):
462510 loader .load_weights ("/nonexistent" , mapping = MagicMock (), model = MagicMock ())
@@ -468,6 +516,7 @@ def _assert_config_timeout(*args, **kwargs):
468516 return {}
469517
470518 loader = MXCheckpointLoader (mx_server_url = "http://mx:8001" , query_timeout_s = 900 )
519+ loader ._source_identity_compatible = MagicMock (return_value = True )
471520 fake_mx = _build_fake_modelexpress (load_weights_side_effect = _assert_config_timeout )
472521 with _install_fake_modelexpress (fake_mx ):
473522 loader .load_weights ("/nonexistent" , mapping = MagicMock (), model = MagicMock ())
0 commit comments