From aee71bbda112727829552f036883afce1b6fc40c Mon Sep 17 00:00:00 2001 From: Brianna Mueller Date: Mon, 30 Mar 2026 23:18:57 +0000 Subject: [PATCH] Fix FedKD SVD decomposition for >2D parameters (e.g. Conv2d) Flatten >2D parameters to 2D before SVD and reshape back on recovery. The previous transpose-based approach produced incorrect shapes that caused broadcast errors during parameter reconstruction. --- system/flcore/clients/clientkd.py | 23 ++++++++++------------- system/flcore/servers/serverkd.py | 25 ++++++++++--------------- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/system/flcore/clients/clientkd.py b/system/flcore/clients/clientkd.py index 39ec75ed..48b7908a 100644 --- a/system/flcore/clients/clientkd.py +++ b/system/flcore/clients/clientkd.py @@ -101,8 +101,11 @@ def set_parameters(self, global_param, energy): # recover for k in global_param.keys(): if len(global_param[k]) == 3: - # use np.matmul to support high-dimensional CNN param - global_param[k] = np.matmul(global_param[k][0] * global_param[k][1][..., None, :], global_param[k][2]) + u, sigma, v = global_param[k] + recovered = np.matmul(u * sigma[..., None, :], v) + # reshape back to original dimensions for >2D params (e.g. Conv2d) + orig_shape = self.global_model.state_dict()[k].shape + global_param[k] = recovered.reshape(orig_shape) for name, old_param in self.global_model.named_parameters(): if name in global_param: @@ -150,15 +153,14 @@ def decomposition(self): param_cpu = param.detach().cpu().numpy() # refer to https://github.com/wuch15/FedKD/blob/main/run.py#L187 if param_cpu.shape[0]>1 and len(param_cpu.shape)>1 and 'embeddings' not in name: + orig_shape = param_cpu.shape + # flatten >2D params (e.g. Conv2d) to 2D for proper SVD + if len(orig_shape) > 2: + param_cpu = param_cpu.reshape(orig_shape[0], -1) u, sigma, v = np.linalg.svd(param_cpu, full_matrices=False) - # support high-dimensional CNN param - if len(u.shape)==4: - u = np.transpose(u, (2, 3, 0, 1)) - sigma = np.transpose(sigma, (2, 0, 1)) - v = np.transpose(v, (2, 3, 0, 1)) threshold=0 if np.sum(np.square(sigma))==0: - compressed_param_cpu=param_cpu + compressed_param_cpu=param_cpu.reshape(orig_shape) else: for singular_value_num in range(len(sigma)): if np.sum(np.square(sigma[:singular_value_num]))>self.energy*np.sum(np.square(sigma)): @@ -167,11 +169,6 @@ def decomposition(self): u=u[:, :threshold] sigma=sigma[:threshold] v=v[:threshold, :] - # support high-dimensional CNN param - if len(u.shape)==4: - u = np.transpose(u, (2, 3, 0, 1)) - sigma = np.transpose(sigma, (1, 2, 0)) - v = np.transpose(v, (2, 3, 0, 1)) compressed_param_cpu=[u,sigma,v] elif 'embeddings' not in name: compressed_param_cpu=param_cpu diff --git a/system/flcore/servers/serverkd.py b/system/flcore/servers/serverkd.py index 2da5fd63..481fd524 100644 --- a/system/flcore/servers/serverkd.py +++ b/system/flcore/servers/serverkd.py @@ -107,10 +107,11 @@ def receive_models(self): # recover for k in client.compressed_param.keys(): if len(client.compressed_param[k]) == 3: - # use np.matmul to support high-dimensional CNN param - client.compressed_param[k] = np.matmul( - client.compressed_param[k][0] * client.compressed_param[k][1][..., None, :], - client.compressed_param[k][2]) + u, sigma, v = client.compressed_param[k] + recovered = np.matmul(u * sigma[..., None, :], v) + # reshape back to original dimensions for >2D params (e.g. Conv2d) + orig_shape = client.global_model.state_dict()[k].shape + client.compressed_param[k] = recovered.reshape(orig_shape) self.uploaded_models.append(client.compressed_param) @@ -134,15 +135,14 @@ def decomposition(self): for name, param_cpu in self.global_model.items(): # refer to https://github.com/wuch15/FedKD/blob/main/run.py#L187 if param_cpu.shape[0]>1 and len(param_cpu.shape)>1 and 'embeddings' not in name: + orig_shape = param_cpu.shape + # flatten >2D params (e.g. Conv2d) to 2D for proper SVD + if len(orig_shape) > 2: + param_cpu = param_cpu.reshape(orig_shape[0], -1) u, sigma, v = np.linalg.svd(param_cpu, full_matrices=False) - # support high-dimensional CNN param - if len(u.shape)==4: - u = np.transpose(u, (2, 3, 0, 1)) - sigma = np.transpose(sigma, (2, 0, 1)) - v = np.transpose(v, (2, 3, 0, 1)) threshold=0 if np.sum(np.square(sigma))==0: - compressed_param_cpu=param_cpu + compressed_param_cpu=param_cpu.reshape(orig_shape) else: for singular_value_num in range(len(sigma)): if np.sum(np.square(sigma[:singular_value_num]))>self.energy*np.sum(np.square(sigma)): @@ -151,11 +151,6 @@ def decomposition(self): u=u[:,:threshold] sigma=sigma[:threshold] v=v[:threshold,:] - # support high-dimensional CNN param - if len(u.shape)==4: - u = np.transpose(u, (2, 3, 0, 1)) - sigma = np.transpose(sigma, (1, 2, 0)) - v = np.transpose(v, (2, 3, 0, 1)) compressed_param_cpu=[u,sigma,v] elif 'embeddings' not in name: compressed_param_cpu=param_cpu