Commit 66d05a2
Fix a gradient clipping bug for layer normalization layers with microbatch axes.
The previous code passed the unstacked gradients (a list) instead of the stacked gradients (a tensor) to the microbatcher, which led to unexpected behavior. This change passes the right argument and changes the original unit test to catch this bug.
PiperOrigin-RevId: 6694130641 parent b396397 commit 66d05a2
2 files changed
Lines changed: 8 additions & 3 deletions
File tree
- tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions
Lines changed: 4 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
80 | 80 | | |
81 | 81 | | |
82 | 82 | | |
83 | | - | |
| 83 | + | |
84 | 84 | | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
85 | 88 | | |
86 | 89 | | |
87 | 90 | | |
| |||
Lines changed: 4 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
134 | 134 | | |
135 | 135 | | |
136 | 136 | | |
137 | | - | |
| 137 | + | |
138 | 138 | | |
139 | 139 | | |
140 | 140 | | |
| |||
147 | 147 | | |
148 | 148 | | |
149 | 149 | | |
150 | | - | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
151 | 153 | | |
152 | 154 | | |
153 | 155 | | |
| |||
0 commit comments