@@ -18,8 +18,8 @@ def cosine_metric2(
1818
1919 for layer in model1 :
2020 if layer in model2 :
21- l1 = model1 [layer ].flatten ()
22- l2 = model2 [layer ].flatten ()
21+ l1 = model1 [layer ].detach (). to ( "cpu" ). flatten ()
22+ l2 = model2 [layer ].detach (). to ( "cpu" ). flatten ()
2323 if l1 .shape != l2 .shape :
2424 # Adjust the shape of the smaller layer to match the larger layer
2525 min_len = min (l1 .shape [0 ], l2 .shape [0 ])
@@ -29,9 +29,7 @@ def cosine_metric2(
2929 cos_similarities .append (cos_sim .item ())
3030
3131 if cos_similarities :
32- avg_cos_sim = torch .mean (torch .tensor (cos_similarities ))
33- # result = torch.clamp(avg_cos_sim, min=0).item()
34- # return result
32+ avg_cos_sim = torch .mean (torch .tensor (cos_similarities , device = "cpu" ))
3533 return avg_cos_sim .item () if similarity else (1 - avg_cos_sim .item ())
3634 else :
3735 return None
@@ -80,8 +78,8 @@ def euclidean_metric(
8078
8179 for layer in model1 :
8280 if layer in model2 :
83- l1 = model1 [layer ].flatten ().float ()
84- l2 = model2 [layer ].flatten ().float ()
81+ l1 = model1 [layer ].detach (). to ( "cpu" ). flatten ().float ()
82+ l2 = model2 [layer ].detach (). to ( "cpu" ). flatten ().float ()
8583
8684 if standardized :
8785 std_l1 , std_l2 = l1 .std (), l2 .std ()
@@ -100,7 +98,7 @@ def euclidean_metric(
10098 distances .append (distance .item ())
10199
102100 if distances :
103- avg_distance = torch .mean (torch .tensor (distances , dtype = torch .float32 ))
101+ avg_distance = torch .mean (torch .tensor (distances , dtype = torch .float32 , device = "cpu" ))
104102 return avg_distance .item () if not torch .isnan (avg_distance ) else 0.0
105103 else :
106104 return None
@@ -119,8 +117,8 @@ def minkowski_metric(
119117
120118 for layer in model1 :
121119 if layer in model2 :
122- l1 = model1 [layer ].flatten ().float ()
123- l2 = model2 [layer ].flatten ().float ()
120+ l1 = model1 [layer ].detach (). to ( "cpu" ). flatten ().float ()
121+ l2 = model2 [layer ].detach (). to ( "cpu" ). flatten ().float ()
124122
125123 distance = torch .norm (l1 - l2 , p = p )
126124 if similarity :
@@ -131,7 +129,7 @@ def minkowski_metric(
131129 distances .append (distance .item ())
132130
133131 if distances :
134- avg_distance = torch .mean (torch .tensor (distances , dtype = torch .float32 ))
132+ avg_distance = torch .mean (torch .tensor (distances , dtype = torch .float32 , device = "cpu" ))
135133 return avg_distance .item () if not torch .isnan (avg_distance ) else 0.0
136134 else :
137135 return None
@@ -149,8 +147,8 @@ def manhattan_metric(
149147
150148 for layer in model1 :
151149 if layer in model2 :
152- l1 = model1 [layer ].flatten ().float ()
153- l2 = model2 [layer ].flatten ().float ()
150+ l1 = model1 [layer ].detach (). to ( "cpu" ). flatten ().float ()
151+ l2 = model2 [layer ].detach (). to ( "cpu" ). flatten ().float ()
154152
155153 distance = torch .norm (l1 - l2 , p = 1 )
156154 if similarity :
@@ -161,7 +159,7 @@ def manhattan_metric(
161159 distances .append (distance .item ())
162160
163161 if distances :
164- avg_distance = torch .mean (torch .tensor (distances , dtype = torch .float32 ))
162+ avg_distance = torch .mean (torch .tensor (distances , dtype = torch .float32 , device = "cpu" ))
165163 return avg_distance .item ()
166164 else :
167165 return None
@@ -179,8 +177,8 @@ def pearson_correlation_metric(
179177
180178 for layer in model1 :
181179 if layer in model2 :
182- l1 = model1 [layer ].flatten ().float ()
183- l2 = model2 [layer ].flatten ().float ()
180+ l1 = model1 [layer ].detach (). to ( "cpu" ). flatten ().float ()
181+ l2 = model2 [layer ].detach (). to ( "cpu" ). flatten ().float ()
184182
185183 if l1 .shape != l2 .shape :
186184 min_len = min (l1 .shape [0 ], l2 .shape [0 ])
@@ -219,8 +217,8 @@ def jaccard_metric(
219217
220218 for layer in model1 :
221219 if layer in model2 :
222- l1 = model1 [layer ].flatten ().float ()
223- l2 = model2 [layer ].flatten ().float ()
220+ l1 = model1 [layer ].detach (). to ( "cpu" ). flatten ().float ()
221+ l2 = model2 [layer ].detach (). to ( "cpu" ). flatten ().float ()
224222
225223 intersection = torch .sum (torch .min (l1 , l2 ))
226224 union = torch .sum (torch .max (l1 , l2 ))
@@ -232,24 +230,23 @@ def jaccard_metric(
232230 jaccard_scores .append (1 - jaccard_sim .item ())
233231
234232 if jaccard_scores :
235- avg_jaccard = torch .mean (torch .tensor (jaccard_scores , dtype = torch .float32 ))
233+ avg_jaccard = torch .mean (torch .tensor (jaccard_scores , dtype = torch .float32 , device = "cpu" ))
236234 return avg_jaccard .item ()
237235 else :
238236 return None
239237
240238
241239def normalise_layers (untrusted_params , trusted_params ):
242- trusted_norms = dict ([k , torch .norm (trusted_params [k ].data .view (- 1 ).float ())] for k in trusted_params .keys ())
240+ trusted_norms = dict ([k , torch .norm (trusted_params [k ].data .to ( "cpu" ). view (- 1 ).float ())] for k in trusted_params .keys ())
243241
244242 normalised_params = copy .deepcopy (untrusted_params )
245243
246244 state_dict = copy .deepcopy (untrusted_params )
247245 for layer in untrusted_params :
248- layer_norm = torch .norm (state_dict [layer ].data .view (- 1 ).float ())
246+ layer_norm = torch .norm (state_dict [layer ].data .to ( "cpu" ). view (- 1 ).float ())
249247 scaling_factor = min (layer_norm / trusted_norms [layer ], 1 )
250248 logging .debug (f"Layer: { layer } ScalingFactor { scaling_factor } " )
251- # logging.info("Scaling client {} layer {} with factor {}".format(client, layer, scaling_factor))
252- normalised_layer = torch .mul (state_dict [layer ], scaling_factor )
249+ normalised_layer = torch .mul (state_dict [layer ].to ("cpu" ), scaling_factor )
253250 normalised_params [layer ] = normalised_layer
254251
255252 return normalised_params
0 commit comments