|
21 | 21 |
|
22 | 22 | import chex |
23 | 23 | from flax.experimental import nnx |
24 | | -from flax.training import train_state |
25 | 24 | import jax |
26 | 25 | import jax.numpy as jnp |
27 | 26 | import jax.sharding |
@@ -84,71 +83,36 @@ def test_diloco_training_simulation_with_mesh(self): |
84 | 83 | tx = optax.sgd(learning_rate=0.1) |
85 | 84 | rngs = nnx.Rngs(params=jax.random.key(seed=42)) |
86 | 85 | model = SimpleNNXModel(rngs=rngs) |
87 | | - graphdef, params = nnx.split(model) |
88 | 86 |
|
89 | | - if test_config.pure_nnx: |
90 | | - optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) |
91 | | - # diloco_test_state expects a TrainStateNNX instance when pure_nnx is True. |
92 | | - initial_test_state = TrainStateNNX(model, optimizer) |
| 87 | + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) |
| 88 | + # diloco_test_state expects a TrainStateNNX instance. |
| 89 | + initial_test_state = TrainStateNNX(model, optimizer) |
93 | 90 |
|
94 | | - # For NNX, train_step needs to take the TrainStateNNX and mutate it |
| 91 | + # train_step takes the TrainStateNNX and mutates it. |
95 | 92 |
|
96 | | - def _test_train_step(state, batch, prng_key: diloco.PRNGKey): |
97 | | - del prng_key |
| 93 | + def _test_train_step(state, batch, prng_key: diloco.PRNGKey): |
| 94 | + del prng_key |
98 | 95 |
|
99 | | - def loss_fn(model, batch): |
100 | | - inputs, labels = batch |
101 | | - logits = jax.vmap(model)(inputs) |
102 | | - residual = logits - labels |
103 | | - return jnp.mean(jnp.square(residual)) |
| 96 | + def loss_fn(model, batch): |
| 97 | + inputs, labels = batch |
| 98 | + logits = jax.vmap(model)(inputs) |
| 99 | + residual = logits - labels |
| 100 | + return jnp.mean(jnp.square(residual)) |
104 | 101 |
|
105 | | - loss, grads = nnx.value_and_grad(loss_fn)(state.model, batch) |
106 | | - state.optimizer.update(state.model, grads) |
107 | | - return state, loss |
108 | | - |
109 | | - else: |
110 | | - |
111 | | - def nnx_apply_fn(params, inputs): |
112 | | - model_replica = nnx.merge(graphdef, params) |
113 | | - return model_replica(inputs) |
114 | | - |
115 | | - # 2. Vmap this new wrapper function |
116 | | - vmapped_apply = jax.vmap(nnx_apply_fn, in_axes=(None, 0)) |
117 | | - |
118 | | - def _test_train_step(state: train_state.TrainState, batch, prng_key: diloco.PRNGKey): |
119 | | - """A simple MSE loss train step to enable numerics testing.""" |
120 | | - del prng_key |
121 | | - |
122 | | - def loss_fn(params, batch): |
123 | | - inputs, labels = batch |
124 | | - logits = vmapped_apply(params, inputs) |
125 | | - residual = logits - labels |
126 | | - sq_residual = jnp.square(residual) |
127 | | - msq_residual = jnp.mean(sq_residual) |
128 | | - return msq_residual |
129 | | - |
130 | | - loss, grad = jax.value_and_grad(loss_fn)(state.params, batch) |
131 | | - return state.apply_gradients(grads=grad), loss |
132 | | - |
133 | | - initial_test_state = train_state.TrainState.create( |
134 | | - apply_fn=vmapped_apply, |
135 | | - params=params, |
136 | | - tx=tx, |
137 | | - ) |
| 102 | + loss, grads = nnx.value_and_grad(loss_fn)(state.model, batch) |
| 103 | + state.optimizer.update(state.model, grads) |
| 104 | + return state, loss |
138 | 105 |
|
139 | 106 | diloco_test_state, _ = diloco.build_diloco_state(test_config, lambda: initial_test_state) |
140 | 107 | chex.assert_equal(diloco_test_state.step, 0) |
141 | | - if test_config.pure_nnx: |
142 | | - _, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...) |
| 108 | + _, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...) |
143 | 109 |
|
144 | | - # diloco_test_state.params might contain nnx.Variables instead of pure arrays. |
145 | | - # We need to unwrap them if they do. |
146 | | - diloco_params_pure = jax.tree_util.tree_map( |
147 | | - lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params |
148 | | - ) |
149 | | - chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict()) |
150 | | - else: |
151 | | - chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params) |
| 110 | + # diloco_test_state.params might contain nnx.Variables instead of pure arrays. |
| 111 | + # We need to unwrap them if they do. |
| 112 | + diloco_params_pure = jax.tree_util.tree_map( |
| 113 | + lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params |
| 114 | + ) |
| 115 | + chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict()) |
152 | 116 |
|
153 | 117 | diloco_train_step = diloco.build_diloco_train_step(test_config, _test_train_step) |
154 | 118 | inputs = jnp.array( |
@@ -196,17 +160,14 @@ def loss_fn(params, batch): |
196 | 160 | chex.assert_equal(diloco_test_state.step, 1.0) |
197 | 161 | chex.assert_equal(loss, 1.0) |
198 | 162 | # Assert no updates to the global model yet (no synchronization) |
199 | | - if test_config.pure_nnx: |
200 | | - _, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...) |
| 163 | + _, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...) |
201 | 164 |
|
202 | | - # diloco_test_state.params might contain nnx.Variables instead of pure arrays. |
203 | | - # We need to unwrap them if they do. |
204 | | - diloco_params_pure = jax.tree_util.tree_map( |
205 | | - lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params |
206 | | - ) |
207 | | - chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict()) |
208 | | - else: |
209 | | - chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params) |
| 165 | + # diloco_test_state.params might contain nnx.Variables instead of pure arrays. |
| 166 | + # We need to unwrap them if they do. |
| 167 | + diloco_params_pure = jax.tree_util.tree_map( |
| 168 | + lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params |
| 169 | + ) |
| 170 | + chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict()) |
210 | 171 |
|
211 | 172 | # Run the second step (no synchronization). |
212 | 173 | # Replica 0: |
@@ -236,17 +197,14 @@ def loss_fn(params, batch): |
236 | 197 | chex.assert_equal(diloco_test_state.step, 2.0) |
237 | 198 | chex.assert_trees_all_close(loss, 0.65) |
238 | 199 | # Assert no updates to the global model yet (no synchronization) |
239 | | - if test_config.pure_nnx: |
240 | | - _, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...) |
| 200 | + _, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...) |
241 | 201 |
|
242 | | - # diloco_test_state.params might contain nnx.Variables instead of pure arrays. |
243 | | - # We need to unwrap them if they do. |
244 | | - diloco_params_pure = jax.tree_util.tree_map( |
245 | | - lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params |
246 | | - ) |
247 | | - chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict()) |
248 | | - else: |
249 | | - chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params) |
| 202 | + # diloco_test_state.params might contain nnx.Variables instead of pure arrays. |
| 203 | + # We need to unwrap them if they do. |
| 204 | + diloco_params_pure = jax.tree_util.tree_map( |
| 205 | + lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params |
| 206 | + ) |
| 207 | + chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict()) |
250 | 208 |
|
251 | 209 | # Run the third step, which synchronizes afterwards. |
252 | 210 | # Replica 0: |
@@ -281,31 +239,21 @@ def loss_fn(params, batch): |
281 | 239 | chex.assert_trees_all_close(loss, 0.4481) |
282 | 240 | # Assert that inner and outer parameters are all equal now that |
283 | 241 | # synchronization has happened. |
284 | | - if test_config.pure_nnx: |
285 | | - _, inner_params, _ = nnx.split(diloco_test_state.inner_state.model, nnx.Param, ...) |
286 | | - inner_params_pure = jax.tree_util.tree_map( |
287 | | - lambda x: x.value if hasattr(x, "value") else x, inner_params.to_pure_dict() |
288 | | - ) |
289 | | - diloco_params_pure_3 = jax.tree_util.tree_map( |
290 | | - lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params |
291 | | - ) |
292 | | - chex.assert_trees_all_equal( |
293 | | - diloco_params_pure_3, |
294 | | - jax.tree.map(lambda arr: arr[0, ...], inner_params_pure), |
295 | | - ) |
296 | | - chex.assert_trees_all_equal( |
297 | | - diloco_params_pure_3, |
298 | | - jax.tree.map(lambda arr: arr[1, ...], inner_params_pure), |
299 | | - ) |
300 | | - else: |
301 | | - chex.assert_trees_all_equal( |
302 | | - diloco_test_state.params, |
303 | | - jax.tree.map(lambda arr: arr[0, ...], diloco_test_state.inner_state.params), |
304 | | - ) |
305 | | - chex.assert_trees_all_equal( |
306 | | - diloco_test_state.params, |
307 | | - jax.tree.map(lambda arr: arr[1, ...], diloco_test_state.inner_state.params), |
308 | | - ) |
| 242 | + _, inner_params, _ = nnx.split(diloco_test_state.inner_state.model, nnx.Param, ...) |
| 243 | + inner_params_pure = jax.tree_util.tree_map( |
| 244 | + lambda x: x.value if hasattr(x, "value") else x, inner_params.to_pure_dict() |
| 245 | + ) |
| 246 | + diloco_params_pure_3 = jax.tree_util.tree_map( |
| 247 | + lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params |
| 248 | + ) |
| 249 | + chex.assert_trees_all_equal( |
| 250 | + diloco_params_pure_3, |
| 251 | + jax.tree.map(lambda arr: arr[0, ...], inner_params_pure), |
| 252 | + ) |
| 253 | + chex.assert_trees_all_equal( |
| 254 | + diloco_params_pure_3, |
| 255 | + jax.tree.map(lambda arr: arr[1, ...], inner_params_pure), |
| 256 | + ) |
309 | 257 |
|
310 | 258 | # Run the fourth step (no synchronization). |
311 | 259 | # Replica 0: |
|
0 commit comments