Skip to content

Commit 93b53b3

Browse files
committed
refactor
1 parent 903ef18 commit 93b53b3

7 files changed

Lines changed: 56 additions & 63 deletions

File tree

src/dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ tiktoken
4646
tokamax!=0.1.0
4747
transformers
4848
uvloop
49-
qwix
49+
qwix>=0.1.6

src/dependencies/requirements/generated_requirements/cuda12-requirements.txt

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ cffi>=2.0.0 ; platform_python_implementation != 'PyPy'
2323
cfgv>=3.5.0
2424
charset-normalizer>=3.4.6
2525
chex>=0.1.91
26-
click>=8.3.2
26+
click>=8.3.3
2727
cloud-accelerator-diagnostics>=0.1.1
2828
cloud-tpu-diagnostics>=0.1.5
2929
cloudpickle>=3.1.2
@@ -40,14 +40,14 @@ dill>=0.4.1
4040
distlib>=0.4.0
4141
distro>=1.9.0
4242
dm-tree>=0.1.10
43-
docstring-parser>=0.17.0
43+
docstring-parser>=0.18.0
4444
drjax>=0.1.4
4545
editdistance>=0.8.1
4646
einops>=0.8.2
4747
einshape>=1.0
4848
etils>=1.14.0
4949
execnet>=2.1.2
50-
fastapi>=0.135.3
50+
fastapi>=0.136.1
5151
filelock>=3.20.3
5252
flatbuffers>=25.12.19
5353
flax>=0.12.6
@@ -61,7 +61,7 @@ google-api-python-client>=2.194.0
6161
google-auth-httplib2>=0.3.1
6262
google-auth-oauthlib>=1.3.1
6363
google-auth>=2.49.2
64-
google-cloud-aiplatform>=1.147.0
64+
google-cloud-aiplatform>=1.148.1
6565
google-cloud-appengine-logging>=1.9.0
6666
google-cloud-audit-log>=0.5.0
6767
google-cloud-bigquery>=3.41.0
@@ -73,7 +73,7 @@ google-cloud-resource-manager>=1.17.0
7373
google-cloud-storage-control>=1.11.0
7474
google-cloud-storage>=3.10.1
7575
google-crc32c>=1.8.0
76-
google-genai>=1.72.0
76+
google-genai>=1.73.1
7777
google-pasta>=0.2.0
7878
google-resumable-media>=2.8.2
7979
googleapis-common-protos>=1.74.0
@@ -88,10 +88,10 @@ hf-xet>=1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or
8888
httpcore>=1.0.9
8989
httplib2>=0.31.2
9090
httpx>=0.28.1
91-
huggingface-hub>=1.10.1
91+
huggingface-hub>=1.11.0
9292
humanize>=4.15.0
9393
hypothesis>=6.142.1
94-
identify>=2.6.18
94+
identify>=2.6.19
9595
idna>=3.11
9696
immutabledict>=4.3.1
9797
importlab>=0.8.1
@@ -155,30 +155,30 @@ opt-einsum>=3.4.0
155155
optax>=0.2.8
156156
optree>=0.19.0
157157
optype>=0.17.0
158-
orbax-checkpoint>=0.11.34
158+
orbax-checkpoint>=0.11.36
159159
orbax-export>=0.0.8
160160
packaging>=26.0
161161
pandas>=3.0.2
162162
parameterized>=0.9.0
163-
pathspec>=1.0.4
163+
pathspec>=1.1.0
164164
pathwaysutils>=0.1.7
165165
pillow>=12.1.1
166166
platformdirs>=4.9.6
167167
pluggy>=1.6.0
168168
portpicker>=1.6.0
169-
pre-commit>=4.5.1
169+
pre-commit>=4.6.0
170170
promise>=2.3
171171
propcache>=0.4.1
172172
proto-plus>=1.27.2
173173
protobuf>=6.33.6
174174
psutil>=7.2.2
175-
pyarrow>=23.0.1
175+
pyarrow>=24.0.0
176176
pyasn1-modules>=0.4.2
177177
pyasn1>=0.6.3
178178
pycnite>=2024.7.31
179179
pycparser>=3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
180-
pydantic-core>=2.46.0
181-
pydantic>=2.13.0
180+
pydantic-core>=2.46.3
181+
pydantic>=2.13.3
182182
pydot>=4.0.1
183183
pyelftools>=0.32
184184
pyglove>=0.4.5
@@ -193,7 +193,7 @@ python-dateutil>=2.9.0.post0
193193
pytokens>=0.4.1
194194
pytype>=2024.10.11
195195
pyyaml>=6.0.3
196-
qwix>=0.1.5
196+
qwix>=0.1.6
197197
regex>=2026.4.4
198198
requests-oauthlib>=2.0.0
199199
requests>=2.32.5
@@ -206,7 +206,7 @@ seqio>=0.0.20
206206
setuptools>=82.0.1
207207
shellingham>=1.5.4
208208
simple-parsing>=0.1.8
209-
simplejson>=3.20.2
209+
simplejson>=4.1.0
210210
six>=1.17.0
211211
sniffio>=1.3.1
212212
sortedcontainers>=2.4.0
@@ -234,17 +234,17 @@ tqdm>=4.67.3
234234
transformer-engine-cu12>=2.13.0
235235
transformer-engine-jax>=2.13.0
236236
transformer-engine>=2.13.0
237-
transformers>=5.5.4
237+
transformers>=5.6.1
238238
treescope>=0.1.10
239239
typeguard>=2.13.3
240-
typer>=0.24.1
240+
typer>=0.24.2
241241
typing-extensions>=4.15.0
242242
typing-inspect>=0.9.0
243243
typing-inspection>=0.4.2
244244
tzdata>=2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32'
245245
uritemplate>=4.2.0
246246
urllib3>=2.6.3
247-
uvicorn>=0.44.0
247+
uvicorn>=0.46.0
248248
uvloop>=0.22.1
249249
virtualenv>=20.36.1
250250
wadler-lindig>=0.1.7

src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ pytype>=2024.10.11
303303
pytz>=2025.2
304304
PyYAML>=6.0.3
305305
pyzmq>=27.1.0
306-
qwix>=0.1.4
306+
qwix>=0.1.6
307307
ray>=2.54.0
308308
referencing>=0.37.0
309309
regex>=2025.11.3

src/dependencies/requirements/generated_requirements/tpu-requirements.txt

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ cffi>=2.0.0 ; platform_python_implementation != 'PyPy'
2323
cfgv>=3.5.0
2424
charset-normalizer>=3.4.6
2525
chex>=0.1.91
26-
click>=8.3.2
26+
click>=8.3.3
2727
cloud-accelerator-diagnostics>=0.1.1
2828
cloud-tpu-diagnostics>=0.1.5
2929
cloudpickle>=3.1.2
@@ -39,14 +39,14 @@ dill>=0.4.1
3939
distlib>=0.4.0
4040
distro>=1.9.0
4141
dm-tree>=0.1.10
42-
docstring-parser>=0.17.0
42+
docstring-parser>=0.18.0
4343
drjax>=0.1.4
4444
editdistance>=0.8.1
4545
einops>=0.8.2
4646
einshape>=1.0
4747
etils>=1.14.0
4848
execnet>=2.1.2
49-
fastapi>=0.135.3
49+
fastapi>=0.136.0
5050
filelock>=3.20.3
5151
flatbuffers>=25.12.19
5252
flax>=0.12.6
@@ -60,7 +60,7 @@ google-api-python-client>=2.194.0
6060
google-auth-httplib2>=0.3.1
6161
google-auth-oauthlib>=1.3.1
6262
google-auth>=2.49.2
63-
google-cloud-aiplatform>=1.147.0
63+
google-cloud-aiplatform>=1.148.1
6464
google-cloud-appengine-logging>=1.9.0
6565
google-cloud-audit-log>=0.5.0
6666
google-cloud-bigquery>=3.41.0
@@ -72,7 +72,7 @@ google-cloud-resource-manager>=1.17.0
7272
google-cloud-storage-control>=1.11.0
7373
google-cloud-storage>=3.10.1
7474
google-crc32c>=1.8.0
75-
google-genai>=1.72.0
75+
google-genai>=1.73.1
7676
google-pasta>=0.2.0
7777
google-resumable-media>=2.8.2
7878
googleapis-common-protos>=1.74.0
@@ -87,10 +87,10 @@ hf-xet>=1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or
8787
httpcore>=1.0.9
8888
httplib2>=0.31.2
8989
httpx>=0.28.1
90-
huggingface-hub>=1.10.1
90+
huggingface-hub>=1.11.0
9191
humanize>=4.15.0
9292
hypothesis>=6.142.1
93-
identify>=2.6.18
93+
identify>=2.6.19
9494
idna>=3.11
9595
immutabledict>=4.3.1
9696
importlab>=0.8.1
@@ -140,30 +140,30 @@ opt-einsum>=3.4.0
140140
optax>=0.2.8
141141
optree>=0.19.0
142142
optype>=0.17.0
143-
orbax-checkpoint>=0.11.34
143+
orbax-checkpoint>=0.11.36
144144
orbax-export>=0.0.8
145145
packaging>=26.0
146146
pandas>=3.0.2
147147
parameterized>=0.9.0
148-
pathspec>=1.0.4
148+
pathspec>=1.1.0
149149
pathwaysutils>=0.1.7
150150
pillow>=12.1.1
151151
platformdirs>=4.9.6
152152
pluggy>=1.6.0
153153
portpicker>=1.6.0
154-
pre-commit>=4.5.1
154+
pre-commit>=4.6.0
155155
promise>=2.3
156156
propcache>=0.4.1
157157
proto-plus>=1.27.2
158158
protobuf>=6.33.6
159159
psutil>=7.2.2
160-
pyarrow>=23.0.1
160+
pyarrow>=24.0.0
161161
pyasn1-modules>=0.4.2
162162
pyasn1>=0.6.3
163163
pycnite>=2024.7.31
164164
pycparser>=3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
165-
pydantic-core>=2.46.0
166-
pydantic>=2.13.0
165+
pydantic-core>=2.46.3
166+
pydantic>=2.13.3
167167
pydot>=4.0.1
168168
pyelftools>=0.32
169169
pyglove>=0.4.5
@@ -191,7 +191,7 @@ seqio>=0.0.20
191191
setuptools>=82.0.1
192192
shellingham>=1.5.4
193193
simple-parsing>=0.1.8
194-
simplejson>=3.20.2
194+
simplejson>=4.1.0
195195
six>=1.17.0
196196
sniffio>=1.3.1
197197
sortedcontainers>=2.4.0
@@ -216,17 +216,17 @@ toml>=0.10.2
216216
tomlkit>=0.14.0
217217
toolz>=1.1.0
218218
tqdm>=4.67.3
219-
transformers>=5.5.4
219+
transformers>=5.6.1
220220
treescope>=0.1.10
221221
typeguard>=2.13.3
222-
typer>=0.24.1
222+
typer>=0.24.2
223223
typing-extensions>=4.15.0
224224
typing-inspect>=0.9.0
225225
typing-inspection>=0.4.2
226226
tzdata>=2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32'
227227
uritemplate>=4.2.0
228228
urllib3>=2.6.3
229-
uvicorn>=0.44.0
229+
uvicorn>=0.46.0
230230
uvloop>=0.22.1
231231
virtualenv>=20.36.1
232232
wadler-lindig>=0.1.7

src/dependencies/requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ pyink
3232
pylint
3333
pytest
3434
pytype
35-
qwix
35+
qwix>=0.1.6
3636
sentencepiece
3737
tensorboard-plugin-profile
3838
tensorboardx

src/maxtext/models/deepseek_batchsplit_fp8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,7 @@ def gmm(
959959
use_qwix_quantization=config.use_qwix_quantization,
960960
use_tokamax_backend=config.use_tokamax_gmm,
961961
weight_gather_axes=weight_gather_axes,
962-
qwix_rule=quantizations.get_fp8_full_qwix_rule(config),
962+
qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config),
963963
)
964964
else:
965965
output = tokamax.ragged_dot(

src/maxtext/trainers/pre_train/train.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,9 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr
127127
rng1, aqt_rng = jax.random.split(dropout_rng)
128128

129129
# Flax Linen model
130-
if sparsity_enabled:
131-
model_vars = {"params": params}
132-
else:
133-
model_vars = params
134-
135-
if sparsity_state and sparsity_enabled:
130+
model_vars = {"params": params}
131+
if sparsity_state:
136132
model_vars["batch_stats"] = sparsity_state
137-
138133
logits, intermediate_outputs = model.apply(
139134
model_vars,
140135
data["inputs"],
@@ -341,16 +336,20 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
341336
params,
342337
params_shardings,
343338
)
344-
sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m
345-
pure_params = params["params"] if sparsity_enabled else params
339+
pure_params = params["params"] if "params" in params else params
346340
batch_stats = params.get("batch_stats", {})
347341

348342
grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True)
349-
350-
kwargs = {"is_train": True}
351-
if sparsity_enabled:
352-
kwargs["sparsity_state"] = batch_stats
353-
(loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, pure_params, *extra_dpo_args, kwargs)
343+
(loss, aux), raw_grads = grad_func(
344+
model,
345+
config,
346+
data,
347+
dropout_rng,
348+
pure_params,
349+
*extra_dpo_args,
350+
sparsity_state=batch_stats,
351+
is_train=True,
352+
)
354353

355354
raw_grads = jax.tree_util.tree_map(
356355
lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
@@ -425,10 +424,9 @@ def move(path, value):
425424
)
426425
)
427426
# Re-wrap grads to match state.params structure if it's a dict of collections
428-
sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m
429-
if sparsity_enabled:
427+
if isinstance(state.params, dict) and "params" in state.params:
430428
full_grads = {"params": grads}
431-
if sparsity_enabled and "batch_stats" in state.params:
429+
if "batch_stats" in state.params:
432430
batch_stats_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {}))
433431
full_grads["batch_stats"] = batch_stats_grads
434432
full_grads = max_utils.unbox_logicallypartioned(full_grads)
@@ -461,7 +459,6 @@ def move(path, value):
461459
and "batch_stats" in state.params
462460
)
463461

464-
jax.debug.print("amanda has_batch_stats: {s}", s=has_batch_stats)
465462
if has_batch_stats:
466463
new_params = dict(new_state.params)
467464
new_params["batch_stats"] = max_utils.unbox_logicallypartioned(aux["batch_stats"])
@@ -524,15 +521,11 @@ def eval_step(model, config, state, data, dropout_rng):
524521
extra_dpo_args = [reference_params]
525522
_loss_fn = dpo_loss_fn
526523

527-
sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m
528-
pure_params = state.params["params"] if sparsity_enabled else state.params
524+
pure_params = state.params["params"] if "params" in state.params else state.params
529525
batch_stats = state.params.get("batch_stats", {})
530526

531527
eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False)
532-
kwargs = {}
533-
if sparsity_enabled:
534-
kwargs["sparsity_state"] = batch_stats
535-
loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, **kwargs)
528+
loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats)
536529

537530
mtp_acceptance_rate = 0.0
538531
if config.mtp_eval_target_module > 0:

0 commit comments

Comments
 (0)