140 | 140 | DefaultFlowCallback, |
141 | 141 | PrinterCallback, |
142 | 142 | ProgressCallback, |
| 143 | + SPGradSyncCallback, |
143 | 144 | TrainerCallback, |
144 | 145 | TrainerControl, |
145 | 146 | TrainerState, |
@@ -444,9 +445,8 @@ def _save_ckpt_func(state_dict, path, signal_path=None): |
444 | 445 | ), "should_save_sharding_stage1_model should be True when using zero cost checkpoint" |
445 | 446 | assert ( |
446 | 447 | ShardingOption.FULL_SHARD not in self.args.sharding |
447 | | - ), "FULL_SHARD is not supported when using zero cost checkpoint" |
448 | | - assert not self.args.save_tokenizer, "save_tokenizer is not supported when using zero cost checkpoint" |
449 | | - assert not self.args.save_rng_states, "save_rng_states is not supported when using zero cost checkpoint" |
| 448 | + ), "FULL_SHARD is not supported when using flash save mode" |
| 449 | + assert not self.args.save_tokenizer, "save_tokenizer is not supported when using flash save mode" |
450 | 450 |
|
451 | 451 | # init attributes for zero cost checkpoint mode |
452 | 452 | self.zcc_manager = None |
@@ -2021,34 +2021,18 @@ def _load_rng_state(self, checkpoint): |
2021 | 2021 | if checkpoint is None: |
2022 | 2022 | return |
2023 | 2023 |
|
2024 | | - # if use distributed training |
2025 | | - if self.args.world_size > 1: |
2026 | | - process_index = self.args.process_index |
2027 | | - rng_file_list = [None for x in range(self.args.world_size)] |
2028 | | - if self.args.should_save: |
2029 | | - rng_file = os.path.join(checkpoint, f"rng_state_{self.args.world_size}.pth") |
2030 | | - if os.path.isfile(rng_file): |
2031 | | - rng_file_list = paddle.load(rng_file, return_numpy=True) |
2032 | | - paddle.distributed.broadcast_object_list(rng_file_list, src=0) |
2033 | | - # if rng_file_list still empty, not log rng state. |
2034 | | - if rng_file_list[0] is None: |
2035 | | - logger.info( |
2036 | | - f"Didn't find an RNG file for process {process_index}, if you are resuming a training that " |
2037 | | - "wasn't launched in a distributed fashion, reproducibility is not guaranteed." |
2038 | | - ) |
2039 | | - return |
2040 | | - else: |
2041 | | - checkpoint_rng_state = rng_file_list[process_index] |
2042 | | - else: |
2043 | | - rng_file = os.path.join(checkpoint, "rng_state.pth") |
2044 | | - if not os.path.isfile(rng_file): |
2045 | | - logger.info( |
2046 | | - "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " |
2047 | | - "fashion, reproducibility is not guaranteed." |
2048 | | - ) |
2049 | | - return |
| 2024 | + rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth") |
| 2025 | + if not os.path.isfile(rng_file): |
| 2026 | + logger.info( |
| 2027 | + "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " |
| 2028 | + "fashion, reproducibility is not guaranteed." |
| 2029 | + ) |
| 2030 | + return |
2050 | 2031 |
|
2051 | | - checkpoint_rng_state = paddle.load(rng_file, return_numpy=True) |
| 2032 | + checkpoint_rng_state = paddle.load(rng_file, return_numpy=True) |
| 2033 | + if checkpoint_rng_state.get("world_size", None) != self.args.world_size: |
| 2034 | + logger.warning("Cannot load RNG states when the world size of the training job has changed.") |
| 2035 | + return |
2052 | 2036 |
|
2053 | 2037 | random.setstate(checkpoint_rng_state["python"]) |
2054 | 2038 | np.random.set_state(checkpoint_rng_state["numpy"]) |
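
The rewritten `_load_rng_state` drops the rank-0 gather-and-broadcast of an `rng_state_{world_size}.pth` list: each rank now loads its own `rng_state_<rank>.pth` and bails out if the stored `world_size` no longer matches. A minimal standalone sketch of that load path (not the Trainer method itself; restoring the device generator states is elided):

```python
import os
import random

import numpy as np
import paddle
import paddle.distributed as dist


def load_rng_state(checkpoint_dir, expected_world_size):
    """Sketch of the per-rank RNG restore performed on resume."""
    rng_file = os.path.join(checkpoint_dir, f"rng_state_{dist.get_rank()}.pth")
    if not os.path.isfile(rng_file):
        # No file for this rank, e.g. an older checkpoint layout; skip restoring RNG.
        return

    state = paddle.load(rng_file, return_numpy=True)
    if state.get("world_size", None) != expected_world_size:
        # Per-rank RNG streams only line up when the process count is unchanged.
        return

    random.setstate(state["python"])
    np.random.set_state(state["numpy"])
    # The "cuda"/"cpu" generator states would be restored here in the same way.
```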
@@ -2210,11 +2194,6 @@ def _wrap_model(self, model, training=True): |
2210 | 2194 | else: |
2211 | 2195 | model, self.optimizer = decorated |
2212 | 2196 |
|
2213 | | - if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel: |
2214 | | - register_sequence_parallel_allreduce_hooks( |
2215 | | - model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce |
2216 | | - ) |
2217 | | - |
2218 | 2197 | if self.args.world_size == 1: |
2219 | 2198 | if self.args.amp_master_grad: |
2220 | 2199 | mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) |
@@ -2403,6 +2382,17 @@ def get_expected_keys(inputs, keys): |
2403 | 2382 | ): |
2404 | 2383 | self.optimizer._set_broadcast_overlap(True, model) |
2405 | 2384 |
|
| 2385 | + # use a callback for sequence-parallel grad sync to avoid unexpected hook behaviour (sharding stage 2 & 3 keep the allreduce hooks) |
| 2386 | + if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel: |
| 2387 | + if ShardingOption.SHARD_GRAD_OP in self.args.sharding or ShardingOption.FULL_SHARD in self.args.sharding: |
| 2388 | + register_sequence_parallel_allreduce_hooks( |
| 2389 | + unwrap_model(model), |
| 2390 | + self.args.gradient_accumulation_steps, |
| 2391 | + self.args.fuse_sequence_parallel_allreduce, |
| 2392 | + ) |
| 2393 | + else: |
| 2394 | + self.add_callback(SPGradSyncCallback(model._layers)) |
| 2395 | + |
2406 | 2396 | return model |
2407 | 2397 |
|
2408 | 2398 | def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]: |
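
Outside of sharding stage 2/3, the per-parameter allreduce hooks are replaced by a trainer callback registered on the unwrapped layers (`model._layers`), so sequence-parallel gradients are synced once per optimizer step rather than through hooks on the wrapped model. A rough, illustrative sketch of such a callback follows; this is not the actual `SPGradSyncCallback`, and the `on_optimizer_begin` event plus the `is_sequence_parallel` marker attribute are assumptions:

```python
import paddle.distributed as dist
from paddle.distributed import fleet
from paddlenlp.trainer import TrainerCallback  # same base class imported at the top of this file


class SPGradSyncSketch(TrainerCallback):
    """Illustrative only: allreduce sequence-parallel grads over the tensor-parallel group."""

    def __init__(self, layers):
        hcg = fleet.get_hybrid_communicate_group()
        self._mp_group = hcg.get_model_parallel_group()
        # Assumption: sequence-parallel parameters carry a marker attribute.
        self._sp_params = [p for p in layers.parameters() if getattr(p, "is_sequence_parallel", False)]

    def on_optimizer_begin(self, args, state, control, **kwargs):
        # Runs once per optimizer step, after gradient accumulation has finished.
        for p in self._sp_params:
            if p.grad is not None:
                dist.all_reduce(p.grad, group=self._mp_group)
```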
@@ -2739,28 +2729,24 @@ def _save_checkpoint(self, model, metrics=None): |
2739 | 2729 | if self.args.should_save: |
2740 | 2730 | self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) |
2741 | 2731 |
|
2742 | | - # Save RNG state in non-distributed training |
2743 | | - rng_states = { |
2744 | | - "python": random.getstate(), |
2745 | | - "numpy": np.random.get_state(), |
2746 | | - "cuda": paddle.get_rng_state(), |
2747 | | - "cpu": paddle.framework.core.default_cpu_generator().get_state(), |
2748 | | - } |
2749 | | - if self.args.use_hybrid_parallel: |
2750 | | - rng_states[ |
2751 | | - "hybrid_parallel_rng_state_tracker" |
2752 | | - ] = fleet.meta_parallel.get_rng_state_tracker().get_states_tracker() |
| 2732 | + if self.args.save_rng_states: |
| 2733 | + # Save RNG states (one file per rank) |
| 2734 | + rng_states = { |
| 2735 | + "python": random.getstate(), |
| 2736 | + "numpy": np.random.get_state(), |
| 2737 | + "cuda": paddle.get_rng_state(), |
| 2738 | + "cpu": paddle.framework.core.default_cpu_generator().get_state(), |
| 2739 | + "world_size": self.args.world_size, |
| 2740 | + } |
| 2741 | + if self.args.use_hybrid_parallel: |
| 2742 | + rng_states[ |
| 2743 | + "hybrid_parallel_rng_state_tracker" |
| 2744 | + ] = fleet.meta_parallel.get_rng_state_tracker().get_states_tracker() |
2753 | 2745 |
|
2754 | 2746 | if self.args.save_rng_states: |
2755 | | - if self.args.world_size > 1: |
2756 | | - rng_states_list = [] |
2757 | | - paddle.distributed.all_gather_object(rng_states_list, rng_states) |
2758 | | - if self.args.should_save: |
2759 | | - os.makedirs(output_dir, exist_ok=True) |
2760 | | - paddle.save(rng_states_list, os.path.join(output_dir, f"rng_state_{self.args.world_size}.pth")) |
2761 | | - else: |
2762 | | - os.makedirs(output_dir, exist_ok=True) |
2763 | | - paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth")) |
| 2747 | + rng_state_file = os.path.join(output_dir, f"rng_state_{dist.get_rank()}.pth") |
| 2748 | + os.makedirs(output_dir, exist_ok=True) |
| 2749 | + paddle.save(rng_states, rng_state_file) |
2764 | 2750 |
|
2765 | 2751 | # only save model state dict, ignore optimizer and scheduler |
2766 | 2752 | if not self.args.ignore_save_lr_and_optim: |
|
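
On the save side, the `all_gather_object` of every rank's RNG dict into a single rank-0 file is gone: when `save_rng_states` is set, each rank writes its own `rng_state_<rank>.pth`, and the stored `world_size` is what the load path above checks before restoring. A minimal sketch of that per-rank save, under the same assumptions:

```python
import os
import random

import numpy as np
import paddle
import paddle.distributed as dist


def save_rng_state(output_dir, world_size):
    """Sketch of the per-rank RNG dump written alongside a checkpoint."""
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cuda": paddle.get_rng_state(),
        "cpu": paddle.framework.core.default_cpu_generator().get_state(),
        "world_size": world_size,  # checked on load so a resized job skips the restore
    }
    os.makedirs(output_dir, exist_ok=True)
    paddle.save(rng_states, os.path.join(output_dir, f"rng_state_{dist.get_rank()}.pth"))
```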