diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 46fbb4ebaf05..8098e20d7044 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -1292,7 +1292,7 @@ def _allocate_or_extend_buffers(self, idx, shape, dtype):
                 self._grad_layer_buf[idx] = new_buf
             return self._grad_layer_buf[idx]
         else:
-            return self._grad_layer_buf[idx].flatten()[:numel].view(shape)
+            return self._grad_layer_buf[idx].flatten()[:numel].view(shape).to(dtype)
 
     def forward(self, *args, **kwargs):
         """Disabled for pipeline parallel training. See ``train_batch()``. """
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 0bfb18877f2d..e4dc531faa11 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -20,8 +20,8 @@
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
-from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
-                                     align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace,
+from deepspeed.runtime.utils import (empty_cache, see_memory_usage, is_model_parallel_parameter, align_dense_tensors,
+                                     all_gather_dp_groups, mask_nan_or_inf_with_val_inplace,
                                      count_used_parameters_in_backward)
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
@@ -1460,8 +1460,11 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
         self.clear_grad_attribute(param)
     #offload only
     def complete_grad_norm_calculation_for_cpu_offload(self, params):
-        total_norm = 0.0
-        norm_type = 2.0
+        """
+        Compute local squared L2 norm of gradients for CPU-offloaded parameters. 
+        No cross-rank communication is performed here.
+        """
+        local_sq_norm = torch.zeros(1, device=self.device, dtype=self.gradient_accumulation_dtype)
         for p in params:
             # Pipeline parallelism may replicate parameters. Avoid multi-counting.
             if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
@@ -1474,7 +1477,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
             # so they have no norm_for_param_grads
             if param_id in self.norm_for_param_grads:
                 param_norm = self.norm_for_param_grads[param_id]
-                total_norm += param_norm.item()**2
+                local_sq_norm += param_norm.item()**2
             else:
                 # As unused parameters in modules may not be expected sometimes,
                 # add an explicit error msg when it occurred and an option to
@@ -1487,19 +1490,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
                     (2) making sure all trainable parameters and `forward` function outputs
                         participate in calculating loss.
                     """
-
-        # Sum across all model parallel GPUs.
-        total_dev_norm = get_accelerator().FloatTensor([float(total_norm)])
-        dist.all_reduce(total_dev_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group)
-
-        self._model_parallel_all_reduce(tensor=total_dev_norm, op=dist.ReduceOp.SUM)
-
-        total_norm = total_dev_norm[0].item()**(1. / norm_type)
-
-        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
-            total_norm = -1.0
-
-        return torch.tensor(total_norm, device=self.device, dtype=torch.float)
+        return local_sq_norm
 
     ############################################################################################
     def copy_grads_in_partition(self, param):
@@ -1912,41 +1903,21 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
         Returns:
             Total norm of the parameters (viewed as a single vector). 
         """
-        norm_type = float(norm_type)
-        all_norms = []
-        if norm_type == inf:
-            for g in gradients:
-                all_norms.append(g.data.abs().max().float())
-            total_norm = torch.stack(all_norms).max()
-            dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=self.dp_process_group)
-
-            # Take max across all GPUs.
-            self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.MAX)
-        else:
-            # if dist.get_rank() == 0:
-            #     logger.info(f"Total Norm beginning {total_norm}")
-            for g, p in zip(gradients, params):
-                # Pipeline parallelism may replicate parameters. Avoid multi-counting.
-                if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
-                    continue
-                if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
-                    all_norms.append(
-                        torch.linalg.vector_norm(g.data.double().detach(),
-                                                 ord=norm_type).to(get_accelerator().current_device_name()))
-            if len(all_norms) > 0:
-                total_norm = torch.stack(all_norms).square().sum().float()
-            else:
-                total_norm = torch.tensor(0.0, dtype=torch.float32).to(self.device)
-            # Sum across all model parallel Device.
-            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group)
+        assert norm_type == 2, "only L2 norm supported"
 
-        self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.SUM)
+        local_sq_norm = torch.zeros(1, device=self.device, dtype=torch.float32)
 
-            total_norm = total_norm.pow(1. / norm_type)
+        for g, p in zip(gradients, params):
+            # Pipeline parallelism may replicate parameters. Avoid multi-counting. 
+            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
+                continue
 
-        mask_nan_or_inf_with_val_inplace(total_norm, device=self.device)
+            if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
+                if g is None:
+                    continue
+                local_sq_norm += torch.sum(g.data.double() * g.data.double())
 
-        return total_norm
+        return local_sq_norm
 
     def get_all_grad_tensors(self, tensor_list, dtype):
         all_grad_tensors = []
@@ -2055,19 +2026,48 @@ def override_loss_scale(self, loss_scale):
 
     def scaled_global_norm(self, norm_type=2):
         assert norm_type == 2, "only L2 norm supported"
-        norm_groups = []
-        for i, group in enumerate(self.bit16_groups):
+        # Collect per-parameter-group squared-norms so MoE averaging can
+        # operate on a per-group basis instead of a single accumulated value.
+        group_sq_norms = []
+        local_total_sq_norm = torch.zeros(1, device=self.device, dtype=self.gradient_accumulation_dtype)
+        for i, _ in enumerate(self.bit16_groups):
             if self.cpu_offload:
-                norm = self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i])
-                norm_groups.append(norm)
+                group_sq_norm = self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i])
             else:
-                norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i]))
+                group_sq_norm = self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i])
+            group_sq_norms.append(group_sq_norm)
+            local_total_sq_norm += group_sq_norm
 
         if self.has_moe_layers:
-            self._average_expert_grad_norms(norm_groups)
+            # _average_expert_grad_norms expects an indexable collection of per-group norms
+            # and updates them in-place for MoE groups. Pass the list instead of the
+            # single accumulated tensor so expert and non-expert groups are scaled
+            # correctly prior to global reduction.
+            self._average_expert_grad_norms(group_sq_norms)
+
+            # Recompute the total from possibly-updated per-group norms to reflect
+            # any MoE-specific averaging that occurred. 
+            local_total_sq_norm = torch.zeros(1, device=self.device, dtype=self.gradient_accumulation_dtype)
+            for g in group_sq_norms:
+                # ensure device/dtype compatibility when summing
+                local_total_sq_norm += g.to(local_total_sq_norm.device)
+
+        # Move tensor to the current accelerator device (supports non-CUDA backends)
+        local_total_sq_norm = local_total_sq_norm.to(get_accelerator().current_device_name())
+        dist.all_reduce(
+            local_total_sq_norm,
+            op=dist.ReduceOp.SUM,
+            group=self.dp_process_group,
+        )
+        self._model_parallel_all_reduce(
+            tensor=local_total_sq_norm,
+            op=dist.ReduceOp.SUM,
+        )
+        total_norm = torch.sqrt(local_total_sq_norm)
 
         # calculating L2 norm
-        return torch.linalg.vector_norm(torch.stack(norm_groups), ord=norm_type)
+        mask_nan_or_inf_with_val_inplace(total_norm, device=self.device)
+        return total_norm
 
     def get_bit16_param_group(self, group_no):
         bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]