In order to make torchjd as efficient as possible, I think we need to limit ourselves to checking only the inputs provided by the user, at the interface (e.g. are the `losses` provided to `mtl_backward` correct?), or things that are unknown before runtime (e.g. are the values of the Jacobian finite?). A few examples of checks that we have (or used to have, in which case the box is ticked) and that reduce performance while never failing:
- [x] Key checks when composing transforms => These always passed in publicly available transforms (i.e. those created by `backward` and `mtl_backward`). Fixed in Isolate key checks in transforms #279.
- [ ] `_AggregateMatrices._disunite` checks that the `united_gradient_vector` has the expected length => This always passes. Fix proposed in Remove `_disunite` check #377.
- [ ] `TensorDict` checks => These also always pass when the transforms are correctly implemented (which a test should verify anyway). Fix proposed in Move tensor dict checks out of constructor #380.
- [ ] The aggregators check that their inputs are finite (i.e. no `nan`, `inf` or `-inf` values allowed). Technically, we can't know ahead of time whether the Jacobian will be finite, so this check is not useless. However, it can be very time- and memory-consuming on huge matrices (and the Jacobian can be huge): it takes 1.62 s on average on my CPU for a matrix of shape [10, 10^8]. On CUDA, it's even worse: it requires the instantiation of another [10, 10^8] matrix (to hold the boolean values), which is too much memory for my GPU to handle. So I think we need to address this. Fix proposed in Remove `_check_is_finite` #381.
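The memory cost of the finiteness check comes from the elementwise mask: `isfinite` materializes a boolean tensor with one entry per element of the Jacobian. Here is a minimal sketch of the problem, using NumPy as a stand-in for torch (`check_is_finite` here is a hypothetical helper, not torchjd's actual implementation):

```python
import numpy as np

def check_is_finite(matrix: np.ndarray) -> None:
    # isfinite allocates a boolean mask with the same shape as the input,
    # so for a [10, 10^8] Jacobian this is an extra ~1 GB allocation
    # (one byte per entry) on top of the Jacobian itself.
    mask = np.isfinite(matrix)
    if not mask.all():
        raise ValueError("The input matrix contains non-finite values.")

jacobian = np.array([[1.0, 2.0], [3.0, float("inf")]])
mask = np.isfinite(jacobian)
assert mask.shape == jacobian.shape  # one boolean per Jacobian entry
```

This is why the check scales with the Jacobian's size in both time and memory, even though it almost always passes.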
A few extra cases that I think are not worth changing (so I tick them already):
- [x] The checks of `get_leaf_tensors` (e.g. `if any([tensor.grad_fn is None for tensor in tensors]):`) could technically be moved from this function to `backward` and `mtl_backward`. However, this brings no performance benefit and it makes the code less readable IMO.
- [x] The aggregators check that their input is a matrix in `Aggregator._check_is_matrix`. In autojac, it will always be the case, but aggregators are also publicly available, so users can apply them to their own tensors (they're thus another interface to torchjd). In that case, the tensors might not be matrices, and it's important to raise a clear error. Performance-wise, this only adds one trivial check per backward pass.
- [x] `_check_expects_grad` could technically be moved from the `Accumulate` transform to the `backward` and `mtl_backward` functions (to be at the interface). When the tensors with respect to which we differentiate are not provided, our `get_leaf_tensors` function should already return tensors that expect a grad, so the check could be skipped. Otherwise, we could simply check that all the tensors provided by the user expect a grad. This would only lead to a (negligible) performance improvement in the first case (otherwise, we would still run the same number of checks, just in a different place), but since the first case is the default, it's very likely to happen. It also goes in the direction of having checks only at the interface. However, I think we decided that the `requires_grad`, `is_leaf` and `retains_grad` fields might change between the creation of the Transform and the time it is applied, so the current version is safer in that respect. For that reason (and because the performance improvement would be negligible), I think it's fine to leave `Accumulate` as it is.
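The timing concern for `_check_expects_grad` can be made concrete with a toy stand-in (plain Python, not torchjd code): a tensor's `requires_grad` flag may be flipped between the moment a transform is built and the moment it is applied, so a check done at construction time can be stale by application time.

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:
    # Stand-in for a torch.Tensor; only the field relevant here is modeled.
    requires_grad: bool = True

def expects_grad(t: FakeTensor) -> bool:
    # Simplified version of the kind of check _check_expects_grad performs.
    return t.requires_grad

t = FakeTensor(requires_grad=True)
checked_at_creation = expects_grad(t)  # the tensor looks valid now...

t.requires_grad = False  # ...but the flag changes before application

assert checked_at_creation       # an early check would have passed
assert not expects_grad(t)       # a check at apply time catches the change
```

This is the argument for keeping the check inside `Accumulate`, where it runs at application time.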
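As a sketch of the kind of clear interface error described for `Aggregator._check_is_matrix` above (a hypothetical helper shown with NumPy, not torchjd's actual implementation), a single dimensionality comparison is negligible next to the aggregation itself:

```python
import numpy as np

def check_is_matrix(tensor: np.ndarray) -> None:
    # One trivial ndim comparison per call: negligible cost, but it gives a
    # clear error to users who call an aggregator on a non-matrix tensor.
    if tensor.ndim != 2:
        raise ValueError(
            f"Expected a matrix (2-D tensor), got a {tensor.ndim}-D tensor."
        )

check_is_matrix(np.zeros((3, 4)))  # a matrix: passes silently
```

Since aggregators are a public interface, raising this error early is worth the one extra check per backward pass.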