-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver_DANN.py.old
More file actions
405 lines (353 loc) · 18.7 KB
/
server_DANN.py.old
File metadata and controls
405 lines (353 loc) · 18.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
from multiprocessing import context
import torch
import torch.optim as optim
import os
import time
from model import BrainCancer
from model_feature_regress import DANN3D, BrainCancerFeaturizer, BrainCancerRegressor
from utils import get_layer_params_list, get_layer_params_dict, flatten_layer_param_list_for_model, reconstruct_layer_from_flat
from utils import debug_function, log_print
import csv
from discriminator import ParamDiscriminator
from typing import Dict, List
class Server:
    """Federated server that splits client weights into invariant and specific parts.

    Each round a parameter-space discriminator is trained to identify a client from
    its flattened weights. The magnitude of the discriminator's input gradient is
    then used, layer by layer, to pick the parameter rows treated as
    domain-specific (kept per client, blended with a running global copy). All
    remaining rows are treated as domain-invariant and averaged across clients.
    """

    @debug_function(context="SERVER")
    def __init__(self, num_clients, val_dataset, test_dataset, model_type="Normal", alpha_var=0.1, beta_sparsity=0.01, run_id=None, num_rounds=10):
        """Create the reference model, the discriminator and the run's log folders.

        Args:
            num_clients: number of clients in the federation. It is also the
                number of classes the parameter discriminator predicts.
            val_dataset: validation dataset (stored only).
            test_dataset: test dataset (stored only).
            model_type: ``"Normal"`` (BrainCancer) or ``"DANN3D"``. Selects the
                reference model whose parameter layout sizes the discriminator.
            alpha_var: regularization weight for variance minimization in the
                M_inv block. Stored only; no method of this class reads it.
            beta_sparsity: regularization weight for L1 (sparsity) in the M_spec
                block. Stored only. Earlier runs used 0.1; the default is now 0.01.
            run_id: suffix of the run directory ``./runs/run_<run_id>``. A
                timestamp is used when it is None.
            num_rounds: number of federated rounds (stored only).

        Raises:
            ValueError: if ``model_type`` is not supported.
        """
        self.run_id = run_id
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.num_clients = num_clients
        self.alpha_var = alpha_var
        self.beta_sparsity = beta_sparsity
        self.num_rounds = num_rounds
        self.domains = []

        if model_type == "Normal":
            self.initial_dummy_model = BrainCancer()
        elif model_type == "DANN3D":
            n_domains = 5
            feat_net = BrainCancerFeaturizer(use_conv5=True)  # or False for conv4
            reg_head = BrainCancerRegressor()
            self.initial_dummy_model = DANN3D(feat_net, reg_head, n_domains)
        else:
            # Without this guard the attributes below would be undefined and fail later.
            raise ValueError(f"Unsupported model_type {model_type!r}; expected 'Normal' or 'DANN3D'")
        self.initial_dummy_paramters_dict = get_layer_params_dict(self.initial_dummy_model)
        self.initial_dummy_paramters_list = get_layer_params_list(self.initial_dummy_model)

        self.num_layers = len(self.initial_dummy_paramters_list)
        # NOTE(review): assumes get_layer_params_list yields tensors (p.numel()); confirm.
        self.d_total = sum(p.numel() for p in self.initial_dummy_paramters_list)
        self.disc = ParamDiscriminator(self.d_total, num_clients)
        self.opt_disc = optim.Adam(self.disc.parameters(), lr=1e-3)
        self.ce_loss = torch.nn.CrossEntropyLoss()

        self.M_spec_layer_wise = {}  # {layer_idx: M_spec of the latest round, (d_l, K)}
        self.prev_spec_global = {}  # {layer_idx: Tensor [d_l]} for each layer
        self.inv_agg = {}  # {layer_idx: aggregated invariant vector of the latest round}
        self.SERVER_LOG_HEADERS = [
            "round",
            "layer_idx",
            "disc_loss",
            "inv_norm",
            "spec_norm",
            "agg_norm",
            "param_diversity",
        ]
        self.client_log_file_paths = []

        # Each run gets its own directory; fall back to a timestamp-based id.
        if self.run_id is None:
            self.run_id = time.strftime("%Y-%m-%d_%H-%M-%S")
        # Use self.run_id (not the raw argument) so a generated id is honoured.
        self.run_dir = os.path.join("./runs", f"run_{self.run_id}")
        os.makedirs(self.run_dir, exist_ok=True)

        self.server_log_dir = os.path.join(self.run_dir, "server_log")
        os.makedirs(self.server_log_dir, exist_ok=True)
        self.server_log_path = os.path.join(self.server_log_dir, "server_metrics.csv")
        self.validation_log_path = os.path.join(self.server_log_dir, "validation_metrics.txt")
        self.text_log_path = os.path.join(self.server_log_dir, "test_metrics.txt")
        self.initialize_server_logger()

        # One metrics-file path per client, all under a shared clients_log directory.
        self.client_log_dir = os.path.join(self.run_dir, "clients_log")
        os.makedirs(self.client_log_dir, exist_ok=True)
        self.client_log_file_paths.extend(
            os.path.join(self.client_log_dir, f"client_{i}_metrics") for i in range(num_clients)
        )

        # { layer_index: aggregated invariant vector } kept for reference.
        self.global_invariant_store = {}

    def initialize_server_logger(self):
        """Create the server metrics CSV with its header row if it does not exist yet."""
        if not os.path.exists(self.server_log_path):
            with open(self.server_log_path, mode="w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(self.SERVER_LOG_HEADERS)

    def log_server_metrics(self, row_dict):
        """Append one row to the metrics CSV, ordered by ``SERVER_LOG_HEADERS``.

        Raises:
            KeyError: if ``row_dict`` lacks one of the header keys.
        """
        with open(self.server_log_path, mode="a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([row_dict[h] for h in self.SERVER_LOG_HEADERS])

    # --- helpers -----------------------------------------------------------
    def _disc_device(self):
        """Return the device holding the discriminator's parameters."""
        return next(self.disc.parameters()).device

    @staticmethod
    def _flatten_client(layers):
        """Concatenate a client's ``[[tensor, ...], ...]`` layers into one detached 1-D tensor."""
        return torch.cat([p.detach().reshape(-1) for layer in layers for p in layer])

    @staticmethod
    def _layer_slice(layers, layer_idx):
        """Return the slice ``layers[layer_idx]`` occupies inside ``_flatten_client(layers)``."""
        start = sum(p.numel() for layer in layers[:layer_idx] for p in layer)
        stop = start + sum(p.numel() for p in layers[layer_idx])
        return slice(start, stop)

    @debug_function(context="SERVER")
    def train_discriminator(self, client_params_dict):
        """Run one discriminator update using the current clients' weights.

        Args:
            client_params_dict: ``{cid: [layer0, layer1, ...]}`` where every layer
                is a list of tensors. Client ids are used directly as class
                labels, so they are presumably ints in ``[0, num_clients)``;
                verify against the caller.

        Returns:
            The cross-entropy loss of this step as a Python float.
        """
        device = self._disc_device()
        # Flatten each client's nested layers (detached, so no grads leak into them).
        theta_batch = torch.stack(
            [self._flatten_client(layers) for layers in client_params_dict.values()]
        ).to(device)  # (K, d_total)
        y_batch = torch.tensor(list(client_params_dict.keys()), dtype=torch.long, device=device)

        self.disc.train()
        self.opt_disc.zero_grad()
        logits = self.disc(theta_batch)
        loss = self.ce_loss(logits, y_batch)
        loss.backward()
        self.opt_disc.step()
        return float(loss.detach().cpu())

    @debug_function(context="SERVER")
    def build_masks(
        self,
        layer_idx: int,
        client_params_dict: Dict[int, List[List[torch.Tensor]]],
        param_matrix: torch.Tensor,
        spec_frac: float = 0.05,
    ) -> torch.Tensor:
        """Return a boolean mask (length ``d_layer``) of the rows routed to M_spec.

        Saliency of a row is the mean, over clients, of |d CE / d theta| for that
        row. The input gradient is taken w.r.t. the client's full flattened
        parameter vector, with the client id as the target class.

        Args:
            layer_idx: index of the layer being masked.
            client_params_dict: ``{cid: [layer0, layer1, ...]}``, the full params per client.
            param_matrix: that same layer stacked across clients, shape ``(d_layer, K)``.
            spec_frac: fraction of rows to treat as domain-specific. At least one
                row is chosen and at most ``d_layer``.

        Returns:
            A bool tensor of length ``d_layer`` on ``param_matrix``'s device, with
            exactly ``k`` True entries.
        """
        self.disc.eval()
        device = self._disc_device()
        d_layer = param_matrix.size(0)
        saliency = torch.zeros(d_layer)  # accumulated |grad| per row (CPU)

        for cid, layers in client_params_dict.items():
            theta_full = self._flatten_client(layers).to(device).requires_grad_(True)
            rows = self._layer_slice(layers, layer_idx)
            # Gradients are required here even if the caller runs under no_grad.
            with torch.enable_grad():
                logits = self.disc(theta_full.unsqueeze(0))  # (1, n_domains)
                loss = self.ce_loss(logits, torch.tensor([cid], device=device))
                # autograd.grad leaves the discriminator's own .grad buffers untouched.
                (grad,) = torch.autograd.grad(loss, theta_full)
            saliency += grad[rows].abs().cpu()

        saliency /= max(len(client_params_dict), 1)  # mean |grad| across clients

        # Select exactly the top-k rows; a >= threshold would also take every tie.
        k = min(d_layer, max(1, int(spec_frac * d_layer)))
        mask_spec = torch.zeros(d_layer, dtype=torch.bool)
        mask_spec[saliency.topk(k).indices] = True
        return mask_spec.to(param_matrix.device)

    @debug_function(context="SERVER")
    def alpha_mix(self, M_spec: torch.Tensor, layer_idx: int, alpha: float = 0.8):
        """Blend each client's ``M_spec[:, j]`` with the previous global spec.

        - M_spec shape: [d, K]
        - prev_spec_global[layer_idx]: shape [d]. It is initialised from the
          current column mean the first time a layer is seen.
        - Returns: ``alpha * M_spec + (1 - alpha) * prev``, shape [d, K]. The
          column mean of the result becomes the new ``prev_spec_global[layer_idx]``.
        """
        if layer_idx not in self.prev_spec_global:
            # First round for this layer: fall back to the current mean.
            self.prev_spec_global[layer_idx] = M_spec.mean(dim=1).clone().detach()
        global_spec = self.prev_spec_global[layer_idx].unsqueeze(1)  # [d, 1]
        blended_spec = alpha * M_spec + (1 - alpha) * global_spec
        self.prev_spec_global[layer_idx] = blended_spec.mean(dim=1).detach()
        return blended_spec

    @debug_function(context="SERVER")
    def gather_client_params(self, client_params_dict, layer_idx):
        """Stack one layer of every client into a ``(d_layer, K)`` matrix.

        Args:
            client_params_dict: ``{client_id: [layer0, layer1, ...]}``, each layer a
                list of tensors.
            layer_idx: which layer to gather.

        Returns:
            A detached tensor whose column j is the flattened layer of the j-th
            client, in dict iteration order.

        Raises:
            ValueError: if ``client_params_dict`` is empty.
        """
        # reshape (not view) also handles non-contiguous tensors; detach avoids
        # tracking autograd history through the server-side aggregation.
        param_list = [
            torch.cat([p.detach().reshape(-1) for p in layer_params[layer_idx]])
            for layer_params in client_params_dict.values()
        ]
        if not param_list:
            raise ValueError(f"[ERROR] Empty param_list for layer {layer_idx}. client_params_dict keys: {list(client_params_dict.keys())}")
        return torch.stack(param_list, dim=1)

    @debug_function(context="SERVER")
    def separate_inv_spec_mask(self, param_matrix, mask_spec):
        """Split ``param_matrix`` (d, K) by a per-row mask, without iteration.

        Args:
            param_matrix: tensor of shape (d, K).
            mask_spec: 1-D bool tensor of length d.

        Returns:
            A pair ``(M_inv, M_spec)``, both of shape (d, K):
            - M_inv: rows not in ``mask_spec`` hold the row mean broadcast to all K
              columns; the masked rows are zero.
            - M_spec: rows in ``mask_spec`` hold the original values; all other
              rows are zero.
        """
        spec_rows = mask_spec.to(device=param_matrix.device, dtype=torch.bool)
        inv_rows = ~spec_rows
        K = param_matrix.size(1)

        M_inv = torch.zeros_like(param_matrix)
        M_inv[inv_rows] = param_matrix[inv_rows].mean(dim=1, keepdim=True).expand(-1, K)

        # M_inv is zero on the spec rows, so the residual there is the raw value.
        M_spec = torch.zeros_like(param_matrix)
        M_spec[spec_rows] = param_matrix[spec_rows]
        return M_inv, M_spec

    @debug_function(context="SERVER")
    def aggregate_invariant(self, M_inv, client_weights=None):
        """Weighted column average of ``M_inv`` (d, K), giving a vector of shape (d,).

        Args:
            M_inv: invariant block of shape (d, K).
            client_weights: optional K non-negative weights, normalised to sum to
                1. Uniform when None.

        Raises:
            ValueError: if the weight count differs from K or the weights sum to 0.
        """
        d, K = M_inv.shape
        if client_weights is None:
            weights = torch.full((K,), 1.0 / K, dtype=M_inv.dtype, device=M_inv.device)
        else:
            if len(client_weights) != K:
                raise ValueError(f"Expected {K} client weights, got {len(client_weights)}")
            total_weight = sum(client_weights)
            if total_weight == 0:
                raise ValueError("client_weights must not sum to zero")
            weights = torch.tensor(
                [w / total_weight for w in client_weights], dtype=M_inv.dtype, device=M_inv.device
            )
        # Allocated on M_inv's device/dtype, so GPU inputs no longer mismatch.
        return M_inv @ weights

    @debug_function(context="SERVER")
    def server_round(self, client_params_dict, num_layers, server_round, client_weights=None):
        """Run one aggregation round and return the new per-client parameters.

        For each layer:
        1) gather param_matrix (d_l, K)
        2) mask rows by discriminator saliency and split into M_inv / M_spec
        3) aggregate M_inv across columns, giving inv_agg
        4) alpha-mix M_spec with the running global spec
        5) new_layer_j = inv_agg + blended_spec[:, j], reshaped like the client's layer
        One CSV metrics row is logged per layer.

        Returns:
            ``{cid: [layer0, layer1, ...]}`` with layers rebuilt by
            ``reconstruct_layer_from_flat``.
        """
        # Train the discriminator once per round.
        disc_loss = self.train_discriminator(client_params_dict)

        all_client_ids = list(client_params_dict.keys())
        updated_client_params = {cid: [] for cid in all_client_ids}
        for layer_idx in range(num_layers):
            param_matrix = self.gather_client_params(client_params_dict, layer_idx)
            mask_spec = self.build_masks(layer_idx, client_params_dict, param_matrix)
            M_inv, M_spec = self.separate_inv_spec_mask(param_matrix, mask_spec)
            inv_agg = self.aggregate_invariant(M_inv, client_weights)

            self.M_spec_layer_wise[layer_idx] = M_spec.clone()
            self.inv_agg[layer_idx] = inv_agg.clone()
            # Store the aggregated invariant itself, not a single client's full layer.
            self.global_invariant_store[layer_idx] = inv_agg.clone()

            blended_spec = self.alpha_mix(M_spec, layer_idx, alpha=0.8)
            for j, cid in enumerate(all_client_ids):
                new_layer_flat = inv_agg + blended_spec[:, j]
                reference_layer = client_params_dict[cid][layer_idx]  # List[Tensor]
                updated_client_params[cid].append(
                    reconstruct_layer_from_flat(new_layer_flat, reference_layer)
                )

            self.log_server_metrics({
                "round": server_round,
                "layer_idx": layer_idx,
                "disc_loss": disc_loss,
                "inv_norm": float(M_inv.norm().item()),
                "spec_norm": float(M_spec.norm().item()),
                "agg_norm": float(inv_agg.norm().item()),
                "param_diversity": float(torch.std(param_matrix, dim=1).mean().item()),
            })
        return updated_client_params
class ServerDomainSpecHelper:
    """Turn the server's per-layer M_spec matrices into per-client vectors and
    draw domain-specific parameters for a new, unseen domain."""

    def __init__(self, server_obj: Server):
        """Keep the server; its ``num_clients`` sets how many columns are read."""
        self.server = server_obj

    @debug_function(context="SERVER DOMAIN SPEC")
    def convert_M_spec_layers_to_clients(self, M_spec_dict):
        """Convert ``{layer_idx: M_spec_l [d_l, K]}`` to ``[M_spec_client0, ..., M_spec_client{K-1}]``.

        Each client vector concatenates that client's column from every layer,
        in ascending ``layer_idx`` order.
        """
        ordered_layers = [M_spec_dict[layer_idx] for layer_idx in sorted(M_spec_dict.keys())]
        return [
            torch.cat([layer[:, client_idx] for layer in ordered_layers])
            for client_idx in range(self.server.num_clients)
        ]

    @debug_function(context="SERVER DOMAIN SPEC")
    def compute_spec_distribution(self, M_spec_list):
        """Element-wise mean and std of the trained domains' spec vectors.

        Args:
            M_spec_list: domain-specific parameter tensors, each of shape [d].

        Returns:
            ``(spec_mean, spec_std)``, each of shape [d]. A 1e-6 epsilon is added
            to the std so it is never zero. With a single domain the std is
            taken as 0 (before the epsilon) rather than NaN.

        Raises:
            ValueError: if ``M_spec_list`` is empty.
        """
        if len(M_spec_list) == 0:
            raise ValueError("compute_spec_distribution needs at least one M_spec tensor")
        stacked_spec = torch.stack(M_spec_list)  # [num_domains, d]
        spec_mean = stacked_spec.mean(dim=0)
        if stacked_spec.size(0) > 1:
            spec_std = stacked_spec.std(dim=0)
        else:
            # The unbiased std of one sample is NaN, which would break torch.normal.
            spec_std = torch.zeros_like(spec_mean)
        return spec_mean, spec_std + 1e-6

    # Example usage after federated training:
    # helper = ServerDomainSpecHelper(server)
    # M_spec_list = helper.convert_M_spec_layers_to_clients(server.M_spec_layer_wise)
    # spec_mean, spec_std = helper.compute_spec_distribution(M_spec_list)

    @debug_function(context="SERVER DOMAIN SPEC")
    def initialize_new_domain_spec(self, spec_mean, spec_std):
        """Draw new domain-specific parameters from N(spec_mean, spec_std**2), element-wise."""
        return torch.normal(mean=spec_mean, std=spec_std)

    # Usage:
    # new_M_spec = helper.initialize_new_domain_spec(spec_mean, spec_std)
    # print(new_M_spec.shape)  # [d]

    @torch.no_grad()
    def sample_mahalanobis(self, M_spec_list, epsilon=1.0):
        """Return a new spec exactly ``epsilon`` Mahalanobis units from the mean.

        The covariance is the sample covariance of ``M_spec_list`` plus 1e-6 * I.
        The returned point is ``mu + epsilon * L @ u``, where ``cov = L L^T`` and
        ``u`` is a random unit direction.

        Raises:
            ValueError: if fewer than two spec vectors are given (covariance undefined).
        """
        X = torch.stack(M_spec_list)  # (K, d)
        if X.size(0) < 2:
            raise ValueError("sample_mahalanobis needs at least two M_spec tensors to estimate a covariance")
        mu = X.mean(0)
        # NOTE: builds a dense (d, d) covariance, so memory grows as d**2.
        cov = torch.cov(X.T) + 1e-6 * torch.eye(X.size(1), device=X.device, dtype=X.dtype)
        L = torch.linalg.cholesky(cov)
        direction = torch.randn_like(mu)
        direction = direction / direction.norm()  # random unit direction
        return mu + L @ direction * epsilon