Skip to content

Commit f5e4953

Browse files
committed
move gradient-accumulation (GA) file from optimizers to utils
1 parent 12fe4ce commit f5e4953

3 files changed

Lines changed: 6 additions & 2 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1383,7 +1383,11 @@ class Profiling(BaseModel):
   xprof_e2e_enable_fw_throttle_event: bool = Field(False, description="Enable FW throttle event.")
   xprof_e2e_enable_fw_power_level_event: bool = Field(False, description="Enable FW power level event.")
   xprof_e2e_enable_fw_thermal_event: bool = Field(False, description="Enable FW thermal event.")
-  profile_power_events: bool = Field(False, description="Enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.")
+  profile_power_events: bool = Field(
+      False,
+      description="Enable TPU-specific power/thermal profiling events."
+      " Defaults to False to avoid breaking GPU xplane tracing.",
+  )
13871391

13881392

13891393
class HloDump(BaseModel):

src/maxtext/trainers/pre_train/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
 from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag, is_decoupled
 from maxtext.common.gcloud_stub import vertex_tensorboard_modules
 from maxtext.common.metric_logger import MetricLogger, record_activation_metrics
-from maxtext.optimizers.gradient_accumulation import gradient_accumulation_loss_and_grad
+from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad
 from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn
 from maxtext.utils import exceptions
 from maxtext.utils import gcs_utils
File renamed without changes.

0 commit comments

Comments (0)