|
36 | 36 | from ..utils.import_utils import is_paddlefleet_available |
37 | 37 | from ..utils.log import logger |
38 | 38 | from ..utils.pdc_sdk import FLASH_DEVICE |
39 | | -from ..utils.tools import get_env_device, paddle_device |
| 39 | +from ..utils.tools import paddle_device |
40 | 40 | from .trainer_utils import ( |
41 | 41 | IntervalStrategy, |
42 | 42 | OptimizerNames, |
@@ -1528,6 +1528,12 @@ class TrainingArguments: |
1528 | 1528 | "help": "Enable splitting backward pass into stages to balance computation and reduce peak memory usage in model parallelism." |
1529 | 1529 | }, |
1530 | 1530 | ) |
| 1531 | + timer: bool = field( |
| 1532 | + default=False, |
| 1533 | + metadata={ |
| 1534 | + "help": "Enable timing for pipeline parallel stages to profile and optimize communication/computation overlap." |
| 1535 | + }, |
| 1536 | + ) |
1531 | 1537 | stage1_tensor_fusion: bool = field( |
1532 | 1538 | default=False, |
1533 | 1539 | metadata={ |
@@ -1945,6 +1951,7 @@ def __post_init__(self): |
1945 | 1951 | "enable_delay_scale_loss", |
1946 | 1952 | "enable_dp_comm_overlap", |
1947 | 1953 | "enable_sharding_comm_overlap", |
| 1954 | + "enable_timer", |
1948 | 1955 | "enable_release_grads", |
1949 | 1956 | "enable_clear_every_step_cache", |
1950 | 1957 | "enable_overlap_p2p_comm", |
@@ -1997,7 +2004,7 @@ def __post_init__(self): |
1997 | 2004 | "delay_scale_loss": True, # TODO[Waynezee]: remove this config in the future |
1998 | 2005 | "dp_comm_overlap": enable_dp_comm_overlap, |
1999 | 2006 | "sharding_comm_overlap": self.enable_sharding_comm_overlap, |
2000 | | - "enable_timer": get_env_device() != "xpu", |
| 2007 | + "enable_timer": self.timer, |
2001 | 2008 | "release_gradients": self.pp_release_grads or self.release_grads, |
2002 | 2009 | "overlap_p2p_comm": self.overlap_p2p_comm, |
2003 | 2010 | "clear_every_step_cache": self.clear_every_step_cache, |
@@ -2428,6 +2435,7 @@ def is_context_parallel_supported(): |
2428 | 2435 | "enable_delay_scale_loss", |
2429 | 2436 | # "enable_dp_comm_overlap", # no implementation for auto_parallel |
2430 | 2437 | # "enable_sharding_comm_overlap", # no implementation for auto_parallel |
| 2438 | + # "enable_timer", # no implementation for auto_parallel |
2431 | 2439 | # "disable_batch_p2p_comm", # no implementation for auto_parallel |
2432 | 2440 | "enable_split_backward", |
2433 | 2441 | "auto_parallel_sync_shared_params", |
|
0 commit comments