File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -160,12 +160,12 @@ def _create_transform(
160160
161161def _create_task_transform (
162162 features : list [Tensor ],
163- tasks_params : list [Tensor ],
163+ task_params : list [Tensor ],
164164 loss : Tensor ,
165165 retain_graph : bool ,
166166) -> Transform [EmptyTensorDict , Gradients ]:
167167 # Tensors with respect to which we compute the gradients.
168- to_differentiate = tasks_params + features
168+ to_differentiate = task_params + features
169169
170170 # Transform that initializes the gradient output to 1.
171171 init = Init ([loss ])
@@ -176,7 +176,7 @@ def _create_task_transform(
176176
177177 # Transform that accumulates the gradients w.r.t. the task-specific parameters into their
178178 # .grad fields.
179- accumulate = Accumulate (tasks_params ) << Select (tasks_params , to_differentiate )
179+ accumulate = Accumulate (task_params ) << Select (task_params , to_differentiate )
180180
181181 # Transform that backpropagates the gradients of the losses w.r.t. the features.
182182 backpropagate = Select (features , to_differentiate )
You can’t perform that action at this time.
0 commit comments