Skip to content

Commit a79b0b7

Browse files
Merge pull request #23 from mlcommons/baseline_update
fixes to nadamw baselines
2 parents 900217a + 6acca16 commit a79b0b7

2 files changed

Lines changed: 40 additions & 60 deletions

File tree

submissions/self_tuning/nadamw/submission.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import jax
1717
import jax.numpy as jnp
1818
import optax
19-
from flax import jax_utils
2019

2120
from algoperf import jax_sharding_utils, spec
2221

@@ -212,7 +211,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
212211
)
213212
optimizer_state = opt_init_fn(params_zeros_like)
214213

215-
return jax_utils.replicate(optimizer_state), opt_update_fn
214+
return optimizer_state, opt_update_fn
216215

217216

218217
def train_step(
@@ -307,12 +306,9 @@ def update_params(
307306
dropout_rate = hyperparameters['dropout_rate']
308307

309308
# Create shardings for each argument
310-
mesh = jax.sharding.Mesh(jax.devices(), ('batch'))
311-
replicated = jax_sharding_utils.get_replicate_sharding(
312-
mesh
313-
) # No partitioning
314-
sharded = jax_sharding_utils.get_batch_sharding(
315-
mesh
309+
replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning
310+
sharded = (
311+
jax_sharding_utils.get_batch_dim_sharding()
316312
) # Partition along batch dimension
317313

318314
# Create the sharding rules for each argument
@@ -344,29 +340,21 @@ def update_params(
344340
in_shardings=arg_shardings,
345341
out_shardings=out_shardings,
346342
)
347-
outputs = jitted_train_step(
348-
workload,
349-
opt_update_fn,
350-
model_state,
351-
optimizer_state,
352-
current_param_container,
353-
batch,
354-
rng,
355-
grad_clip,
356-
label_smoothing,
357-
dropout_rate,
358-
)
359-
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
360-
361-
# Log loss, grad_norm.
362-
if global_step % 100 == 0 and workload.metrics_logger is not None:
363-
workload.metrics_logger.append_scalar_metrics(
364-
{
365-
'loss': loss[0],
366-
'grad_norm': grad_norm[0],
367-
},
368-
global_step,
343+
344+
new_optimizer_state, new_params, new_model_state, loss, grad_norm = (
345+
jitted_train_step(
346+
workload,
347+
opt_update_fn,
348+
model_state,
349+
optimizer_state,
350+
current_param_container,
351+
batch,
352+
rng,
353+
grad_clip,
354+
label_smoothing,
355+
dropout_rate,
369356
)
357+
)
370358
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
371359

372360

@@ -415,6 +403,8 @@ def get_batch_size(workload_name):
415403
return 512
416404
elif workload_name == 'wmt':
417405
return 128
406+
elif workload_name == 'finewebedu_lm':
407+
return 64
418408
elif workload_name == 'mnist':
419409
return 16
420410
else:

submissions/self_tuning/nadamw_baselinev05/submission.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import jax
1717
import jax.numpy as jnp
1818
import optax
19-
from flax import jax_utils
2019

2120
from algoperf import jax_sharding_utils, spec
2221

@@ -212,7 +211,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
212211
)
213212
optimizer_state = opt_init_fn(params_zeros_like)
214213

215-
return jax_utils.replicate(optimizer_state), opt_update_fn
214+
return optimizer_state, opt_update_fn
216215

217216

218217
def train_step(
@@ -307,12 +306,9 @@ def update_params(
307306
dropout_rate = hyperparameters['dropout_rate']
308307

309308
# Create shardings for each argument
310-
mesh = jax.sharding.Mesh(jax.devices(), ('batch'))
311-
replicated = jax_sharding_utils.get_replicate_sharding(
312-
mesh
313-
) # No partitioning
314-
sharded = jax_sharding_utils.get_batch_sharding(
315-
mesh
309+
replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning
310+
sharded = (
311+
jax_sharding_utils.get_batch_dim_sharding()
316312
) # Partition along batch dimension
317313

318314
# Create the sharding rules for each argument
@@ -344,29 +340,21 @@ def update_params(
344340
in_shardings=arg_shardings,
345341
out_shardings=out_shardings,
346342
)
347-
outputs = jitted_train_step(
348-
workload,
349-
opt_update_fn,
350-
model_state,
351-
optimizer_state,
352-
current_param_container,
353-
batch,
354-
rng,
355-
grad_clip,
356-
label_smoothing,
357-
dropout_rate,
358-
)
359-
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
360-
361-
# Log loss, grad_norm.
362-
if global_step % 100 == 0 and workload.metrics_logger is not None:
363-
workload.metrics_logger.append_scalar_metrics(
364-
{
365-
'loss': loss[0],
366-
'grad_norm': grad_norm[0],
367-
},
368-
global_step,
343+
344+
new_optimizer_state, new_params, new_model_state, loss, grad_norm = (
345+
jitted_train_step(
346+
workload,
347+
opt_update_fn,
348+
model_state,
349+
optimizer_state,
350+
current_param_container,
351+
batch,
352+
rng,
353+
grad_clip,
354+
label_smoothing,
355+
dropout_rate,
369356
)
357+
)
370358
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
371359

372360

@@ -415,6 +403,8 @@ def get_batch_size(workload_name):
415403
return 512
416404
elif workload_name == 'wmt':
417405
return 128
406+
elif workload_name == 'finewebedu_lm':
407+
return 64
418408
elif workload_name == 'mnist':
419409
return 16
420410
else:

0 commit comments

Comments
 (0)