1+ # This Source Code Form is subject to the terms of the Mozilla Public
2+ # License, v. 2.0. If a copy of the MPL was not distributed with this
3+ # file, You can obtain one at https://mozilla.org/MPL/2.0/.
4+ #
5+ # This Source Code Form is "Incompatible With Secondary Licenses", as
6+ # defined by the Mozilla Public License, v. 2.0.
7+
8+ from .engine import Engine
9+ from ..modeling .dl_models import DL_Models as dlm
10+ import numpy as np
11+
12+ import torch
13+ import torch .nn as nn
14+ import torch .nn .functional as F
15+ import math
16+ import copy
17+ from math import floor
18+ from tqdm import tqdm
19+
20+
21+ class DLLABoostingEnsemble :
22+ def __init__ (self , base_model_fn , n_estimators = 3 , lr = 0.001 , device = None ):
23+ self .base_model_fn = base_model_fn
24+ self .n_estimators = n_estimators
25+ self .lr = lr
26+ self .sensitivity = None
27+ self .models = []
28+ self .device = device if device else torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
29+ self .criterion = nn .BCELoss ()
30+ self .epoch_batches = []
31+ self .current_stage = 0
32+ self .epochs_per_stage = 1
33+ self .ready = False
34+ self .curr_model = None
35+ self .optimizer = None
36+ self ._fixed_epoch_batches = None
37+
38+
39+ def begin_stage (self ):
40+ self .curr_model = self .base_model_fn ().to (self .device )
41+ self .optimizer = torch .optim .Adam (self .curr_model .parameters (), lr = 1e-3 )
42+ self .ready = True
43+
44+
45+ def update (self , X_batch , y_batch , log_batch = True ):
46+ if not self .ready :
47+ self .begin_stage ()
48+
49+ X_batch = X_batch .to (self .device ).float ()
50+ y_batch = y_batch .to (self .device ).long ()
51+
52+ # compute current ensemble prediction (logits)
53+ with torch .no_grad ():
54+ ensemble_logits = torch .zeros (X_batch .size (0 ), 2 ).to (self .device )
55+ for model in self .models :
56+ ensemble_logits += self .lr * model (X_batch )
57+
58+ if len (self .models ) > 0 :
59+ ensemble_logits /= (self .lr * len (self .models ))
60+ else :
61+ ensemble_logits = torch .full_like (y_batch , 0.5 )
62+
63+ # compute residual pseudo-targets (added stability)
64+ target = (y_batch - ensemble_logits ).detach ()
65+ target = (target + 1.0 ) / 2.0
66+ target = target .clamp (0.0 , 1.0 )
67+
68+ # train current model
69+ self .curr_model .train ()
70+ pred_logits = self .curr_model (X_batch )
71+
72+ eps = 1e-5
73+ pred_logits = pred_logits .clamp (eps , 1. - eps )
74+
75+ loss = self .criterion (pred_logits , target )
76+
77+ self .optimizer .zero_grad ()
78+ loss .backward ()
79+ self .optimizer .step ()
80+
81+
82+ def end_stage (self ):
83+ # save only model's state_dict
84+ self .models .append (copy .deepcopy (self .curr_model ).eval ())
85+ torch .cuda .empty_cache () # helps prevent CUDA OOM
86+ self .current_stage += 1
87+ self .ready = False
88+
89+
90+ def finish_training (self ):
91+ if self ._fixed_epoch_batches is None :
92+ # copy batches to CPU and detach
93+ self ._fixed_epoch_batches = [
94+ (X .detach ().cpu (), y .detach ().cpu ())
95+ for (X , y ) in self .epoch_batches
96+ ]
97+
98+ for _ in range (self .n_estimators - self .current_stage ):
99+ self .begin_stage ()
100+ for _ in range (self .epochs_per_stage ):
101+ for X , y in self ._fixed_epoch_batches :
102+ self .update (X , y , log_batch = False )
103+ self .end_stage ()
104+
105+
106+ def predict (self , X ):
107+ X = X .to (self .device ).float ()
108+ logits = torch .zeros ((X .size (0 ), 2 ), device = self .device )
109+ for model in self .models :
110+ logits += self .lr * model (X )
111+ probs = F .softmax (logits , dim = 1 )
112+
113+ # return class index
114+ return torch .argmax (probs , dim = 1 )
115+
116+
117+ def compute_sensitivity (self , X_input ):
118+ self .curr_model .eval ()
119+ X_input = X_input .to (self .device ).float ()
120+ X_input .requires_grad = True
121+
122+ # accumulate logits from all ensemble models
123+ ensemble_logits = torch .zeros ((X_input .size (0 ), 2 ), device = self .device )
124+ for model in self .models :
125+ model .eval ()
126+ logits = model (X_input )
127+ ensemble_logits += self .lr * logits
128+ ensemble_logits /= (self .lr * len (self .models ))
129+
130+ # use the logit difference for class sensitivity
131+ class_diff = ensemble_logits [:, 1 ] - ensemble_logits [:, 0 ]
132+
133+ # backpropagate to input features
134+ class_diff .sum ().backward ()
135+ gradients = X_input .grad
136+
137+ # aggregate absolute gradients across all samples
138+ sensitivity_scores = gradients .abs ().mean (dim = 0 ).detach ().cpu ().numpy ()
139+ return sensitivity_scores
140+
141+
142+
143+
144+ class Boosting (Engine ):
145+ def __init__ (self , model_type , num_estimators , train_float , num_epochs ) -> None :
146+ # construction parameters
147+ self .model_type = model_type
148+ self .train_float = train_float
149+ self .num_epochs = num_epochs
150+ self .num_estimators = num_estimators
151+ # initialize values needed
152+ self .samples_len = 0
153+ self .batch_size = 0
154+ self .traces_len = 0
155+ self .batches_num = 0
156+ self .counted_batches = 0
157+ self .data_dtype = None
158+ self .sensitivity = None
159+ self .sens_tensor = None
160+ self .p_value = 0
161+ # validation values
162+ self .accuracy = 0
163+ self .actual_labels = None
164+ self .pred_labels = None
165+ self .predicted_classes = None
166+
167+
168+ def populate (self , container ):
169+ # initialize dimensional variables
170+ self .samples_len = container .min_samples_length
171+ self .traces_len = container .min_traces_length
172+ self .batch_size = container .data .batch_size
173+ self .batches_num = int (self .traces_len / self .batch_size )
174+ # assign per-tile train and validation data
175+ for tile in container .tiles :
176+ (tile_x , tile_y ) = tile
177+ # config batches
178+ container .configure (tile_x , tile_y , [0 ])
179+ container .configure2 (tile_x , tile_y , [0 ])
180+
181+
182+ def fetch_training_batch (self , container , i ):
183+ batch1 = container .get_batch_index (i )[- 1 ]
184+ batch2 = container .get_batch_index2 (i )[- 1 ]
185+ current_data = np .concatenate ((batch1 , batch2 ), axis = 0 )
186+ label1 = np .zeros (len (batch1 ))
187+ label2 = np .ones (len (batch2 ))
188+ current_labels = np .concatenate ((label1 , label2 ), axis = 0 )
189+ current_labels = np .eye (2 )[current_labels .astype (int )] # one-hot encode labels
190+ return current_data , current_labels
191+
192+
193+ def fetch_validation_batch (self , container , i , batch_size ):
194+ batch1 = container .get_batch_index (i )[- 1 ]
195+ batch2 = container .get_batch_index2 (i )[- 1 ]
196+ current_data = np .concatenate ((batch1 , batch2 ), axis = 0 )
197+ label1 = np .zeros (batch_size )
198+ label2 = np .ones (batch_size )
199+ current_labels = np .concatenate ((label1 , label2 ), axis = 0 )
200+ return current_data , current_labels
201+
202+
203+ def train_ensemble (self , container ):
204+ num_batches = floor ((self .traces_len / self .batch_size ) * self .train_float )
205+ print (f"Training { self .num_estimators } estimators on { num_batches } batches" )
206+
207+ # feed batches
208+ for i in tqdm (range (num_batches ), desc = "Processing batches" ):
209+ data_np , labels_oh = self .fetch_training_batch (container , i )
210+ X = torch .tensor (data_np , dtype = torch .float32 )
211+ y = torch .tensor (labels_oh , dtype = torch .float32 )
212+ self .model .update (X , y )
213+ self .counted_batches += 1
214+
215+ self .model .finish_training ()
216+
217+ print (f"Finished training { self .num_estimators } models." )
218+
219+
220+ def validate_ensemble (self , container ):
221+ num_val_batches = self .batches_num - self .counted_batches
222+ print (f'Validating on { num_val_batches } batches' )
223+
224+ X_new = np .empty ((2 * num_val_batches * self .batch_size , self .samples_len ))
225+ Y_test = np .empty ((2 * num_val_batches * self .batch_size ))
226+
227+ for i in tqdm (range (num_val_batches ), desc = "Processing batches" ):
228+ current_data , current_labels = self .fetch_validation_batch (container , i + int (self .batches_num * self .train_float ), self .batch_size )
229+
230+ start_idx = i * self .batch_size
231+ end_idx = start_idx + 2 * self .batch_size
232+ X_new [start_idx :end_idx ] = current_data
233+ Y_test [start_idx :end_idx ] = current_labels
234+
235+ # save labels
236+ self .actual_labels = Y_test [:]
237+ # make new data into tensors
238+ X_new_tensor = torch .tensor (X_new , dtype = torch .float32 )
239+
240+ # make predictions
241+ preds = self .model .predict (X_new_tensor )
242+ preds = preds .cpu ().numpy (force = True )
243+
244+ # calculate accuracy
245+ correct_predictions = np .sum (preds == Y_test )
246+ self .accuracy = correct_predictions / len (Y_test )
247+
248+ print (f"Made { preds .shape [0 ]} predictions with { self .accuracy :.2%} accuracy using the { self .model_type } model." )
249+
250+ # sensitivity stuff
251+ self .sens_tensor = X_new_tensor
252+
253+
254+ def run (self , container , model_building = False , model_validation = False ):
255+ # training
256+ if model_building :
257+ self .populate (container )
258+ # initialize boosting ensemble
259+ self .model = DLLABoostingEnsemble (
260+ base_model_fn = lambda : dlm .eMLP (self .samples_len ),
261+ n_estimators = self .num_estimators ,
262+ lr = 0.001 ,
263+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
264+ )
265+ self .train_ensemble (container )
266+ # validation
267+ if model_validation :
268+ self .validate_ensemble (container )
269+
270+
271+ def get_sensitivity (self ):
272+ self .sensitivity = self .model .compute_sensitivity (self .sens_tensor )
273+ return self .sensitivity
274+
275+
276+ def get_accuracy (self ):
277+ return self .accuracy
278+
279+
280+ # ===== p-value and leakage stuff =====
281+ def binom_log_pmf (self ,k , n , p ):
282+ if p == 0.0 : return float ('-inf' ) if k > 0 else 0.0
283+ if p == 1.0 : return float ('-inf' ) if k < n else 0.0
284+ return (
285+ math .lgamma (n + 1 ) - math .lgamma (k + 1 ) - math .lgamma (n - k + 1 )
286+ + k * math .log (p )
287+ + (n - k ) * math .log (1 - p )
288+ )
289+
290+
291+ def logsumexp (self , log_probs ):
292+ max_log = max (log_probs )
293+ return max_log + math .log (sum (math .exp (lp - max_log ) for lp in log_probs ))
294+
295+
296+ def get_log10_binom_tail (self , k_min , n , p ):
297+ if k_min > n : return float ('-inf' ) # log(0)
298+
299+ log_probs = [self .binom_log_pmf (k , n , p ) for k in range (k_min , n + 1 )]
300+ log_p_value = self .logsumexp (log_probs )
301+ log10_p_value = log_p_value / math .log (10 ) # convert ln(p) to log10(p)
302+ return - log10_p_value
303+
304+
305+ def get_leakage (self , p_th = 1e-5 ):
306+ M = self .traces_len - self .counted_batches * self .batch_size
307+ sM = int (np .floor (self .accuracy * M ))
308+ sM = max (0 , min (sM , M ))
309+
310+ # compute -log10(p)
311+ neg_log10_p = self .get_log10_binom_tail (sM , M , 0.5 )
312+ self .p_value = 10 ** (- neg_log10_p ) # only for comparison/display
313+
314+ if self .p_value <= p_th :
315+ print (f"Leakage detected: p-value ≈ { self .p_value :.2e} , -log10(p) ≈ { neg_log10_p :.2f} " )
316+ else :
317+ print (f"No significant leakage: p-value ≈ { self .p_value :.2e} , -log10(p) ≈ { neg_log10_p :.2f} " )
318+
319+ return self .p_value , neg_log10_p
0 commit comments