Fix #1040: [Feature] Fixed bugs in Archon LoRA Backend #1139
Changes from all commits
```diff
@@ -99,11 +99,22 @@ def get_grad_norm_fp32(
     norm_type = float(norm_type)
     total_norm = 0.0
 
-    if not grads_for_norm:
-        return 0.0
-
     device = current_platform.current_device()
 
+    if not grads_for_norm:
+        # Still participate in all_reduce with zero contribution so that
+        # ranks with grads don't hang waiting for this rank (e.g. LoRA frozen ranks).
+        total_norm_cuda = torch.tensor(0.0, dtype=torch.float, device=device)
+        reduce_op = dist.ReduceOp.MAX if norm_type == torch.inf else dist.ReduceOp.SUM
+        if data_parallel_group:
+            dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group)
+        if model_parallel_group is not None:
+            dist.all_reduce(total_norm_cuda, op=reduce_op, group=model_parallel_group)
+        total_norm = float(total_norm_cuda.item())
+        if norm_type != torch.inf and total_norm > 0:
+            total_norm = total_norm ** (1.0 / norm_type)
+        return total_norm
+
```
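For background on why the new branch cannot simply return early any more: `torch.distributed.all_reduce` is a collective, so every rank in a process group must issue the call or the remaining ranks block indefinitely. A minimal sketch of that constraint (hypothetical helper name, not code from this PR):

```python
import torch
import torch.distributed as dist

def join_grad_norm_reduce(local_sq_norm: float, group=None) -> float:
    """Hypothetical helper: every rank in `group` must call this, even ranks
    whose parameters are all frozen (local_sq_norm == 0.0), because
    all_reduce blocks until every rank in the group has joined."""
    # A CPU tensor works for e.g. the gloo backend; NCCL would need a CUDA tensor.
    contribution = torch.tensor(local_sq_norm, dtype=torch.float)
    dist.all_reduce(contribution, op=dist.ReduceOp.SUM, group=group)
    return float(contribution.item()) ** 0.5
```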
|
Comment on lines +104 to +116
Contributor

The logic for participating in the `all_reduce` and post-processing the norm is now duplicated between this early-return branch and the main code path below it. Consider refactoring the function to unify the synchronization and post-processing steps. For instance, you could initialize the per-rank contribution to zero when `grads_for_norm` is empty and let every rank fall through to the shared `all_reduce` and normalization code.
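A rough sketch of the kind of unification being suggested, assuming the argument names visible in the diff (`grads_for_norm`, `norm_type`, `data_parallel_group`, `model_parallel_group`) and the project's `current_platform.current_device()` helper; defaults and local variable names are illustrative rather than taken from the actual implementation:

```python
import torch
import torch.distributed as dist

def get_grad_norm_fp32(grads_for_norm, norm_type=2.0,
                       data_parallel_group=None, model_parallel_group=None):
    norm_type = float(norm_type)
    device = current_platform.current_device()  # project helper, as used in the diff

    # Every rank computes a local contribution; ranks with no grads contribute
    # zero and still reach the shared all_reduce below, so the early-return
    # branch (and its duplicated reduction logic) is no longer needed.
    if norm_type == torch.inf:
        reduce_op = dist.ReduceOp.MAX
        local = max((float(g.abs().max()) for g in grads_for_norm), default=0.0)
    else:
        reduce_op = dist.ReduceOp.SUM
        local = sum(float(torch.norm(g, norm_type)) ** norm_type for g in grads_for_norm)

    total_norm_cuda = torch.tensor(local, dtype=torch.float, device=device)

    # Single synchronization path shared by every rank in the groups.
    if data_parallel_group is not None:
        dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group)
    if model_parallel_group is not None:
        dist.all_reduce(total_norm_cuda, op=reduce_op, group=model_parallel_group)

    total_norm = float(total_norm_cuda.item())
    if norm_type != torch.inf and total_norm > 0:
        total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
```

With a single reduction path, ranks with and without trainable gradients execute the same collectives in the same order, which removes the hang the PR fixes without duplicating the post-processing.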
The hunk continues into the existing code path below the new branch:

```diff
     if norm_type == torch.inf:
         norms = [grad.abs().max() for grad in grads_for_norm]
         total_norm = torch.max(torch.stack(norms)) if norms else 0.0
```
There is an inconsistency in how the existence of process groups is checked. `data_parallel_group` is checked using truthiness, while `model_parallel_group` is checked using an explicit `is not None` comparison. Per PEP 8, it is generally preferred to use `is not None` when checking whether an optional argument was provided, to avoid potential issues with objects that might define a custom `__bool__` method. Using a consistent check also improves readability.
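A minimal suggestion along those lines, changing only the first check so both optional groups are tested the same way (identifiers as in the diff):

```diff
-        if data_parallel_group:
+        if data_parallel_group is not None:
             dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group)
         if model_parallel_group is not None:
             dist.all_reduce(total_norm_cuda, op=reduce_op, group=model_parallel_group)
```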