Skip to content

Commit 33fc0c9

Browse files
authored
fix(autojac): Rename internal variable (#282)
* In mtl_backward.py, rename the parameter tasks_params of _create_task_transform to task_params, as it represents the parameters of a single task
1 parent 1f0afca commit 33fc0c9

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/torchjd/autojac/mtl_backward.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,12 @@ def _create_transform(
160160

161161
def _create_task_transform(
162162
features: list[Tensor],
163-
tasks_params: list[Tensor],
163+
task_params: list[Tensor],
164164
loss: Tensor,
165165
retain_graph: bool,
166166
) -> Transform[EmptyTensorDict, Gradients]:
167167
# Tensors with respect to which we compute the gradients.
168-
to_differentiate = tasks_params + features
168+
to_differentiate = task_params + features
169169

170170
# Transform that initializes the gradient output to 1.
171171
init = Init([loss])
@@ -176,7 +176,7 @@ def _create_task_transform(
176176

177177
# Transform that accumulates the gradients w.r.t. the task-specific parameters into their
178178
# .grad fields.
179-
accumulate = Accumulate(tasks_params) << Select(tasks_params, to_differentiate)
179+
accumulate = Accumulate(task_params) << Select(task_params, to_differentiate)
180180

181181
# Transform that backpropagates the gradients of the losses w.r.t. the features.
182182
backpropagate = Select(features, to_differentiate)

0 commit comments

Comments
 (0)