@@ -253,8 +253,8 @@ <h1>Monitoring aggregations<a class="headerlink" href="#monitoring-aggregations"
253253they have a negative inner product).</ p >
254254< div class ="highlight-python notranslate "> < div class ="highlight "> < pre > < span > </ span > < span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> torch</ span >
255255< span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> torch.nn</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> Linear</ span > < span class ="p "> ,</ span > < span class ="n "> MSELoss</ span > < span class ="p "> ,</ span > < span class ="n "> ReLU</ span > < span class ="p "> ,</ span > < span class ="n "> Sequential</ span >
256- < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> torch.optim</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> SGD</ span >
257256< span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> torch.nn.functional</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> cosine_similarity</ span >
257+ < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> torch.optim</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> SGD</ span >
258258
259259< span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> torchjd</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> mtl_backward</ span >
260260< span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> torchjd.aggregation</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> UPGrad</ span >
@@ -263,7 +263,7 @@ <h1>Monitoring aggregations<a class="headerlink" href="#monitoring-aggregations"
263263</ span > < span class ="hll "> < span class ="w "> </ span > < span class ="sd "> """Prints the extracted weights."""</ span >
264264</ span > < span class ="hll "> < span class ="nb "> print</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "Weights: </ span > < span class ="si "> {</ span > < span class ="n "> weights</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span > < span class ="p "> )</ span >
265265</ span >
266- < span class ="hll "> < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> print_similarity_with_gd </ span > < span class ="p "> (</ span > < span class ="n "> _</ span > < span class ="p "> ,</ span > < span class ="n "> inputs</ span > < span class ="p "> :</ span > < span class ="nb "> tuple</ span > < span class ="p "> [</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ],</ span > < span class ="n "> aggregation</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
266+ < span class ="hll "> < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> print_gd_similarity </ span > < span class ="p "> (</ span > < span class ="n "> _</ span > < span class ="p "> ,</ span > < span class ="n "> inputs</ span > < span class ="p "> :</ span > < span class ="nb "> tuple</ span > < span class ="p "> [</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class =" p " > , </ span > < span class =" o " > ... </ span > < span class ="p "> ],</ span > < span class ="n "> aggregation</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
267267</ span > < span class ="hll "> < span class ="w "> </ span > < span class ="sd "> """Prints the cosine similarity between the aggregation and the average gradient."""</ span >
268268</ span > < span class ="hll "> < span class ="n "> matrix</ span > < span class ="o "> =</ span > < span class ="n "> inputs</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span >
269269</ span > < span class ="hll "> < span class ="n "> gd_output</ span > < span class ="o "> =</ span > < span class ="n "> matrix</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> (</ span > < span class ="n "> dim</ span > < span class ="o "> =</ span > < span class ="mi "> 0</ span > < span class ="p "> )</ span >
@@ -284,7 +284,7 @@ <h1>Monitoring aggregations<a class="headerlink" href="#monitoring-aggregations"
284284< span class ="n "> aggregator</ span > < span class ="o "> =</ span > < span class ="n "> UPGrad</ span > < span class ="p "> ()</ span >
285285
286286< span class ="hll "> < span class ="n "> aggregator</ span > < span class ="o "> .</ span > < span class ="n "> weighting</ span > < span class ="o "> .</ span > < span class ="n "> register_forward_hook</ span > < span class ="p "> (</ span > < span class ="n "> print_weights</ span > < span class ="p "> )</ span >
287- </ span > < span class ="hll "> < span class ="n "> aggregator</ span > < span class ="o "> .</ span > < span class ="n "> register_forward_hook</ span > < span class ="p "> (</ span > < span class ="n "> print_similarity_with_gd </ span > < span class ="p "> )</ span >
287+ </ span > < span class ="hll "> < span class ="n "> aggregator</ span > < span class ="o "> .</ span > < span class ="n "> register_forward_hook</ span > < span class ="p "> (</ span > < span class ="n "> print_gd_similarity </ span > < span class ="p "> )</ span >
288288</ span >
289289< span class ="n "> inputs</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> randn</ span > < span class ="p "> (</ span > < span class ="mi "> 8</ span > < span class ="p "> ,</ span > < span class ="mi "> 16</ span > < span class ="p "> ,</ span > < span class ="mi "> 10</ span > < span class ="p "> )</ span > < span class ="c1 "> # 8 batches of 16 random input vectors of length 10</ span >
290290< span class ="n "> task1_targets</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> randn</ span > < span class ="p "> (</ span > < span class ="mi "> 8</ span > < span class ="p "> ,</ span > < span class ="mi "> 16</ span > < span class ="p "> ,</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span > < span class ="c1 "> # 8 batches of 16 targets for the first task</ span >
0 commit comments