22# Copyright lowRISC contributors.
33# Licensed under the Apache License, Version 2.0, see LICENSE for details.
44# SPDX-License-Identifier: Apache-2.0
5+ """A distributed implementation of the correlation-enhanced power analysis
6+ collision attack.
7+
8+ See "Correlation-Enhanced Power Analysis Collision Attack" by A. Moradi, O.
9+ Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf) for more
10+ information.
11+
12+ Typical usage:
13+ >>> ./ceca.py -f PROJECT_FILE -n 400000 -w 5 -a 117 127 -d output -s 3
14+ """
515
616import argparse
717import enum
2535from capture .project_library .project import ProjectConfig # noqa : E402
2636from capture .project_library .project import SCAProject # noqa : E402
2737
28- """A distributed implementation of the correlation-enhanced power analysis
29- collision attack.
30-
31- See "Correlation-Enhanced Power Analysis Collision Attack" by A. Moradi, O.
32- Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf) for more
33- information.
34-
35- Typical usage:
36- >>> ./ceca.py -f PROJECT_FILE -n 400000 -w 5 -a 117 127 -d output -s 3
37- """
38-
3938
4039def timer ():
4140 """A customization of the ``codetiming.Timer`` decorator."""
4241
4342 def decorator (func ):
43+
4444 @codetiming .Timer (
4545 name = func .__name__ ,
4646 text = f"{ func .__name__ } took {{seconds:.1f}}s" ,
@@ -79,7 +79,8 @@ class TraceWorker:
7979 >>> results = ray.get(tasks)
8080 """
8181
82- def __init__ (self , project_file , trace_slice , attack_window , attack_direction ):
82+ def __init__ (self , project_file , trace_slice , attack_window ,
83+ attack_direction ):
8384 """Inits a TraceWorker.
8485
8586 Args:
@@ -94,26 +95,27 @@ def __init__(self, project_file, trace_slice, attack_window, attack_direction):
9495 project_type = "ot_trace_library"
9596
9697 # Open the project.
97- project_cfg = ProjectConfig (
98- type = project_type , path = project_file , wave_dtype = np .uint16 , overwrite = False
99- )
98+ project_cfg = ProjectConfig (type = project_type ,
99+ path = project_file ,
100+ wave_dtype = np .uint16 ,
101+ overwrite = False )
100102 self .project = SCAProject (project_cfg )
101103 self .project .open_project ()
102104
103105 # TODO: Consider more efficient formats.
104106 self .num_samples = attack_window .stop - attack_window .start
105107 if attack_direction == AttackDirection .INPUT :
106108 self .texts = np .vstack (
107- self .project .get_plaintexts (trace_slice .start , trace_slice . stop )
108- )
109+ self .project .get_plaintexts (trace_slice .start ,
110+ trace_slice . stop ) )
109111 else :
110112 self .texts = np .vstack (
111- self .project .get_ciphertexts (trace_slice .start , trace_slice . stop )
112- )
113+ self .project .get_ciphertexts (trace_slice .start ,
114+ trace_slice . stop ) )
113115
114116 self .traces = np .asarray (
115- self .project .get_waves (trace_slice .start , trace_slice . stop )
116- )[:, attack_window ]
117+ self .project .get_waves (trace_slice .start ,
118+ trace_slice . stop ) )[:, attack_window ]
117119
118120 self .project .close (save = False )
119121
@@ -132,7 +134,7 @@ def compute_stats(self):
132134 cnt = self .traces .shape [0 ]
133135 sum_ = self .traces .sum (axis = 0 )
134136 mean = sum_ / cnt
135- sum_dev_prods = ((self .traces - mean ) ** 2 ).sum (axis = 0 )
137+ sum_dev_prods = ((self .traces - mean )** 2 ).sum (axis = 0 )
136138 return (cnt , sum_ , sum_dev_prods )
137139
138140 def filter_noisy_traces (self , min_trace , max_trace ):
@@ -146,8 +148,7 @@ def filter_noisy_traces(self, min_trace, max_trace):
146148 Number of remaining traces.
147149 """
148150 traces_to_use = np .all (
149- (self .traces >= min_trace ) & (self .traces <= max_trace ), axis = 1
150- )
151+ (self .traces >= min_trace ) & (self .traces <= max_trace ), axis = 1 )
151152 self .traces = self .traces [traces_to_use ]
152153 self .texts = self .texts [traces_to_use ]
153154 return self .traces .shape [0 ]
@@ -214,12 +215,12 @@ def compute_mean_and_std(workers):
214215 running_cnt += cnt
215216 else :
216217 running_sum_dev_prods += sum_dev_prods + (
217- (cnt * running_sum - running_cnt * sum_ ) ** 2 /
218- (cnt * running_cnt * (cnt + running_cnt ))
219- )
218+ (cnt * running_sum - running_cnt * sum_ )** 2 /
219+ (cnt * running_cnt * (cnt + running_cnt )))
220220 running_sum += sum_
221221 running_cnt += cnt
222- return running_sum / running_cnt , np .sqrt (running_sum_dev_prods / running_cnt )
222+ return running_sum / running_cnt , np .sqrt (running_sum_dev_prods /
223+ running_cnt )
223224
224225
225226def filter_noisy_traces (workers , mean_trace , std_trace , max_std ):
@@ -237,7 +238,8 @@ def filter_noisy_traces(workers, mean_trace, std_trace, max_std):
237238 min_trace = mean_trace - max_std * std_trace
238239 max_trace = mean_trace + max_std * std_trace
239240 tasks = [
240- worker .filter_noisy_traces .remote (min_trace , max_trace ) for worker in workers
241+ worker .filter_noisy_traces .remote (min_trace , max_trace )
242+ for worker in workers
241243 ]
242244
243245 running_cnt = 0
@@ -392,7 +394,8 @@ def find_best_diffs(pairwise_diffs_scores):
392394 # the most likely differences between key bytes.
393395 G .add_edge (a , b , weight = DiffScore (pairwise_diffs_scores [a , b , 1 ]))
394396 # Find paths from key byte 0 to all other bytes.
395- paths = nx .algorithms .shortest_paths .weighted .single_source_dijkstra_path (G , 0 )
397+ paths = nx .algorithms .shortest_paths .weighted .single_source_dijkstra_path (
398+ G , 0 )
396399 # Recover the paths and corresponding differences from key byte 0 to all
397400 # other bytes.
398401 diffs = np .zeros (16 , dtype = np .uint8 )
@@ -422,9 +425,11 @@ def recover_key(diffs, attack_direction, plaintext, ciphertext):
422425 # Create a matrix of all possible keys.
423426 keys = np .zeros ((256 , 16 ), np .uint8 )
424427 for first_byte_val in range (256 ):
425- key = np .asarray ([diffs [i ] ^ first_byte_val for i in range (16 )], np .uint8 )
428+ key = np .asarray ([diffs [i ] ^ first_byte_val for i in range (16 )],
429+ np .uint8 )
426430 if attack_direction == AttackDirection .OUTPUT :
427- key = np .asarray (cwa .aes_funcs .key_schedule_rounds (key , 10 , 0 ), np .uint8 )
431+ key = np .asarray (cwa .aes_funcs .key_schedule_rounds (key , 10 , 0 ),
432+ np .uint8 )
428433 keys [first_byte_val ] = key
429434 # Encrypt the plaintext using all candidates in parallel.
430435 ciphertexts = scared .aes .base .encrypt (plaintext , keys )
@@ -464,9 +469,8 @@ def compare_diffs(pairwise_diffs_scores, attack_direction, correct_key):
464469
465470
466471@timer ()
467- def perform_attack (
468- project_file , num_traces , attack_window , attack_direction , max_std , num_workers
469- ):
472+ def perform_attack (project_file , num_traces , attack_window , attack_direction ,
473+ max_std , num_workers ):
470474 """Performs a correlation-enhanced power analysis collision attack.
471475
472476 This function:
@@ -506,9 +510,10 @@ def perform_attack(
506510 project_type = "ot_trace_library"
507511
508512 # Open the project.
509- project_cfg = ProjectConfig (
510- type = project_type , path = project_file , wave_dtype = np .uint16 , overwrite = False
511- )
513+ project_cfg = ProjectConfig (type = project_type ,
514+ path = project_file ,
515+ wave_dtype = np .uint16 ,
516+ overwrite = False )
512517 project = SCAProject (project_cfg )
513518 project .open_project ()
514519
@@ -524,11 +529,11 @@ def perform_attack(
524529 f"Invalid attack window: { attack_window } (must be in [0, { last_sample } ])"
525530 )
526531 if max_std <= 0 :
527- raise ValueError (f"Invalid max_std: { max_std } (must be greater than zero)" )
532+ raise ValueError (
533+ f"Invalid max_std: { max_std } (must be greater than zero)" )
528534 if num_workers <= 0 :
529535 raise ValueError (
530- f"Invalid num_workers: { num_workers } (must be greater than zero)"
531- )
536+ f"Invalid num_workers: { num_workers } (must be greater than zero)" )
532537
533538 # Instantiate workers
534539 def worker_trace_slices ():
@@ -539,15 +544,15 @@ def worker_trace_slices():
539544 traces_per_worker = int (num_traces / num_workers )
540545 first_worker_num_traces = traces_per_worker + num_traces % num_workers
541546 yield slice (0 , first_worker_num_traces )
542- for trace_begin in range (
543- first_worker_num_traces , num_traces , traces_per_worker
544- ):
547+ for trace_begin in range (first_worker_num_traces , num_traces ,
548+ traces_per_worker ):
545549 yield slice (trace_begin , trace_begin + traces_per_worker )
546550
547551 # Attack window is inclusive.
548552 attack_window = slice (attack_window [0 ], attack_window [1 ] + 1 )
549553 workers = [
550- TraceWorker .remote (project_file , trace_slice , attack_window , attack_direction )
554+ TraceWorker .remote (project_file , trace_slice , attack_window ,
555+ attack_direction )
551556 for trace_slice in worker_trace_slices ()
552557 ]
553558 assert len (workers ) == num_workers
@@ -556,32 +561,27 @@ def worker_trace_slices():
556561 # Filter noisy traces.
557562 orig_num_traces = num_traces
558563 num_traces = filter_noisy_traces (workers , mean , std_dev , max_std )
559- logging .info (
560- f"Will use { num_traces } traces "
561- f"({ 100 * num_traces / orig_num_traces :.1f} % of all traces)"
562- )
564+ logging .info (f"Will use { num_traces } traces "
565+ f"({ 100 * num_traces / orig_num_traces :.1f} % of all traces)" )
563566 # Mean traces for all values of all text bytes.
564567 mean_text_traces = compute_mean_text_traces (workers )
565568 # Guess the differences between key bytes.
566569 pairwise_diffs_scores = compute_pairwise_diffs_and_scores (mean_text_traces )
567570 diffs = find_best_diffs (pairwise_diffs_scores )
568571 logging .info (f"Difference values (delta_0_i): { diffs } " )
569572 # Recover the key.
570- key = recover_key (
571- diffs , attack_direction , project .get_plaintexts (0 ), project .get_ciphertexts (0 )
572- )
573+ key = recover_key (diffs , attack_direction , project .get_plaintexts (0 ),
574+ project .get_ciphertexts (0 ))
573575 if key is not None :
574576 logging .info (f"Recovered AES key: { bytes (key ).hex ()} " )
575577 else :
576578 logging .error ("Failed to recover the AES key" )
577579 # Compare differences - both matrices are symmetric and have an all-zero main diagonal.
578- correct_diffs = compare_diffs (
579- pairwise_diffs_scores , attack_direction , project .get_keys (0 )
580- )
580+ correct_diffs = compare_diffs (pairwise_diffs_scores , attack_direction ,
581+ project .get_keys (0 ))
581582 logging .info (
582583 f"Recovered { ((np .sum (correct_diffs ) - 16 ) / 2 ).astype (int )} /120 "
583- "differences between key bytes"
584- )
584+ "differences between key bytes" )
585585 project .close (save = False )
586586 return key
587587
@@ -591,8 +591,7 @@ def parse_args():
591591 parser = argparse .ArgumentParser (
592592 description = """A distributed implementation of the attack described in
593593 "Correlation-Enhanced Power Analysis Collision Attack" by A. Moradi, O.
594- Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf)."""
595- )
594+ Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf).""" )
596595 parser .add_argument (
597596 "-f" ,
598597 "--project-file" ,
@@ -649,8 +648,7 @@ def config_logger():
649648 sh = logging .StreamHandler ()
650649 sh .setLevel (logging .INFO )
651650 formatter = logging .Formatter (
652- "%(asctime)s %(levelname)s %(filename)s:%(lineno)d -- %(message)s"
653- )
651+ "%(asctime)s %(levelname)s %(filename)s:%(lineno)d -- %(message)s" )
654652 sh .setFormatter (formatter )
655653 logger .addHandler (sh )
656654 return logger
@@ -665,9 +663,9 @@ def main():
665663 ray .init (
666664 runtime_env = {
667665 "working_dir" : "../" ,
668- "excludes" : [ "*.db" , "*.cwp" , "*.npy" , "*.bit" , "*/lfs/*" , "*.pack" ],
669- }
670- )
666+ "excludes" :
667+ [ "*.db" , "*.cwp" , "*.npy" , "*.bit" , "*/lfs/*" , "*.pack" ],
668+ } )
671669
672670 key = perform_attack (** vars (args ))
673671 sys .exit (0 if key is not None else 1 )
0 commit comments