Skip to content

Commit dccf241

Browse files
authored
Revert "Remove timer training argument" (#4505)
1 parent a8ab43b commit dccf241

4 files changed

Lines changed: 20 additions & 12 deletions

File tree

examples/experiments/ernie_pretrain/ernie/pretrain.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -347,11 +347,12 @@ def main():
347347
if getattr(config.trainer_args, "dp_comm_overlap", False):
348348
logger.warning("Pipeline dp_comm_overlap and FusedLinearWithGradAdd can not be used at the same time.")
349349

350-
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
351-
PipelineParallel,
352-
)
350+
if getattr(config.trainer_args, "timer", False):
351+
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
352+
PipelineParallel,
353+
)
353354

354-
PipelineParallel.timer_printer = lambda _: None
355+
PipelineParallel.timer_printer = lambda _: None
355356

356357
def formatv(v):
357358
if isinstance(v, ListConfig):

paddleformers/cli/train/ernie_pretrain/workflow.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -356,11 +356,12 @@ def run_ernie_pretrain(model_args, data_args, generating_args, training_args):
356356
if getattr(training_args, "dp_comm_overlap", False):
357357
logger.warning("Pipeline dp_comm_overlap and FusedLinearWithGradAdd can not be used at the same time.")
358358

359-
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
360-
PipelineParallel,
361-
)
359+
if getattr(training_args, "timer", False):
360+
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
361+
PipelineParallel,
362+
)
362363

363-
PipelineParallel.timer_printer = lambda _: None
364+
PipelineParallel.timer_printer = lambda _: None
364365

365366
def formatv(v):
366367
if isinstance(v, ListConfig):

paddleformers/trainer/trainer.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,6 @@
8282
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
8383
DygraphShardingOptimizerV2,
8484
)
85-
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import PipelineParallel
8685
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
8786
fused_allreduce_gradients,
8887
)
@@ -446,7 +445,6 @@ def __init__(
446445

447446
set_profile_timers(self.timers)
448447
self.runtime_timer = RuntimeTimer("RuntimeTimer")
449-
PipelineParallel.timer_printer = lambda _: None
450448

451449
self.model_wrapped = model
452450
self.model = model

paddleformers/trainer/training_args.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
from ..utils.import_utils import is_paddlefleet_available
3737
from ..utils.log import logger
3838
from ..utils.pdc_sdk import FLASH_DEVICE
39-
from ..utils.tools import get_env_device, paddle_device
39+
from ..utils.tools import paddle_device
4040
from .trainer_utils import (
4141
IntervalStrategy,
4242
OptimizerNames,
@@ -1528,6 +1528,12 @@ class TrainingArguments:
15281528
"help": "Enable splitting backward pass into stages to balance computation and reduce peak memory usage in model parallelism."
15291529
},
15301530
)
1531+
timer: bool = field(
1532+
default=False,
1533+
metadata={
1534+
"help": "Enable timing for pipeline parallel stages to profile and optimize communication/computation overlap."
1535+
},
1536+
)
15311537
stage1_tensor_fusion: bool = field(
15321538
default=False,
15331539
metadata={
@@ -1945,6 +1951,7 @@ def __post_init__(self):
19451951
"enable_delay_scale_loss",
19461952
"enable_dp_comm_overlap",
19471953
"enable_sharding_comm_overlap",
1954+
"enable_timer",
19481955
"enable_release_grads",
19491956
"enable_clear_every_step_cache",
19501957
"enable_overlap_p2p_comm",
@@ -1997,7 +2004,7 @@ def __post_init__(self):
19972004
"delay_scale_loss": True, # TODO[Waynezee]: remove this config in the future
19982005
"dp_comm_overlap": enable_dp_comm_overlap,
19992006
"sharding_comm_overlap": self.enable_sharding_comm_overlap,
2000-
"enable_timer": get_env_device() != "xpu",
2007+
"enable_timer": self.timer,
20012008
"release_gradients": self.pp_release_grads or self.release_grads,
20022009
"overlap_p2p_comm": self.overlap_p2p_comm,
20032010
"clear_every_step_cache": self.clear_every_step_cache,
@@ -2428,6 +2435,7 @@ def is_context_parallel_supported():
24282435
"enable_delay_scale_loss",
24292436
# "enable_dp_comm_overlap", # no implementation for auto_parallel
24302437
# "enable_sharding_comm_overlap", # no implementation for auto_parallel
2438+
# "enable_timer", # no implementation for auto_parallel
24312439
# "disable_batch_p2p_comm", # no implementation for auto_parallel
24322440
"enable_split_backward",
24332441
"auto_parallel_sync_shared_params",

0 commit comments

Comments
 (0)