Skip to content

Commit ced6d2f

Browse files
committed
Autolint python files using yapf
To keep things consistent, use the same auto python linter as we use in OpenTitan (yapf) Signed-off-by: Pascal Nasahl <nasahlpa@lowrisc.org>
1 parent 75741d1 commit ced6d2f

66 files changed

Lines changed: 2758 additions & 2839 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

analysis/ceca.py

Lines changed: 62 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22
# Copyright lowRISC contributors.
33
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
44
# SPDX-License-Identifier: Apache-2.0
5+
"""A distributed implementation of the correlation-enhanced power analysis
6+
collision attack.
7+
8+
See "Correlation-Enhanced Power Analysis Collision Attack" by A. Moradi, O.
9+
Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf) for more
10+
information.
11+
12+
Typical usage:
13+
>>> ./ceca.py -f PROJECT_FILE -n 400000 -w 5 -a 117 127 -d output -s 3
14+
"""
515

616
import argparse
717
import enum
@@ -25,22 +35,12 @@
2535
from capture.project_library.project import ProjectConfig # noqa : E402
2636
from capture.project_library.project import SCAProject # noqa : E402
2737

28-
"""A distributed implementation of the correlation-enhanced power analysis
29-
collision attack.
30-
31-
See "Correlation-Enhanced Power Analysis Collision Attack" by A. Moradi, O.
32-
Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf) for more
33-
information.
34-
35-
Typical usage:
36-
>>> ./ceca.py -f PROJECT_FILE -n 400000 -w 5 -a 117 127 -d output -s 3
37-
"""
38-
3938

4039
def timer():
4140
"""A customization of the ``codetiming.Timer`` decorator."""
4241

4342
def decorator(func):
43+
4444
@codetiming.Timer(
4545
name=func.__name__,
4646
text=f"{func.__name__} took {{seconds:.1f}}s",
@@ -79,7 +79,8 @@ class TraceWorker:
7979
>>> results = ray.get(tasks)
8080
"""
8181

82-
def __init__(self, project_file, trace_slice, attack_window, attack_direction):
82+
def __init__(self, project_file, trace_slice, attack_window,
83+
attack_direction):
8384
"""Inits a TraceWorker.
8485
8586
Args:
@@ -94,26 +95,27 @@ def __init__(self, project_file, trace_slice, attack_window, attack_direction):
9495
project_type = "ot_trace_library"
9596

9697
# Open the project.
97-
project_cfg = ProjectConfig(
98-
type=project_type, path=project_file, wave_dtype=np.uint16, overwrite=False
99-
)
98+
project_cfg = ProjectConfig(type=project_type,
99+
path=project_file,
100+
wave_dtype=np.uint16,
101+
overwrite=False)
100102
self.project = SCAProject(project_cfg)
101103
self.project.open_project()
102104

103105
# TODO: Consider more efficient formats.
104106
self.num_samples = attack_window.stop - attack_window.start
105107
if attack_direction == AttackDirection.INPUT:
106108
self.texts = np.vstack(
107-
self.project.get_plaintexts(trace_slice.start, trace_slice.stop)
108-
)
109+
self.project.get_plaintexts(trace_slice.start,
110+
trace_slice.stop))
109111
else:
110112
self.texts = np.vstack(
111-
self.project.get_ciphertexts(trace_slice.start, trace_slice.stop)
112-
)
113+
self.project.get_ciphertexts(trace_slice.start,
114+
trace_slice.stop))
113115

114116
self.traces = np.asarray(
115-
self.project.get_waves(trace_slice.start, trace_slice.stop)
116-
)[:, attack_window]
117+
self.project.get_waves(trace_slice.start,
118+
trace_slice.stop))[:, attack_window]
117119

118120
self.project.close(save=False)
119121

@@ -132,7 +134,7 @@ def compute_stats(self):
132134
cnt = self.traces.shape[0]
133135
sum_ = self.traces.sum(axis=0)
134136
mean = sum_ / cnt
135-
sum_dev_prods = ((self.traces - mean) ** 2).sum(axis=0)
137+
sum_dev_prods = ((self.traces - mean)**2).sum(axis=0)
136138
return (cnt, sum_, sum_dev_prods)
137139

138140
def filter_noisy_traces(self, min_trace, max_trace):
@@ -146,8 +148,7 @@ def filter_noisy_traces(self, min_trace, max_trace):
146148
Number of remaining traces.
147149
"""
148150
traces_to_use = np.all(
149-
(self.traces >= min_trace) & (self.traces <= max_trace), axis=1
150-
)
151+
(self.traces >= min_trace) & (self.traces <= max_trace), axis=1)
151152
self.traces = self.traces[traces_to_use]
152153
self.texts = self.texts[traces_to_use]
153154
return self.traces.shape[0]
@@ -214,12 +215,12 @@ def compute_mean_and_std(workers):
214215
running_cnt += cnt
215216
else:
216217
running_sum_dev_prods += sum_dev_prods + (
217-
(cnt * running_sum - running_cnt * sum_) ** 2 /
218-
(cnt * running_cnt * (cnt + running_cnt))
219-
)
218+
(cnt * running_sum - running_cnt * sum_)**2 /
219+
(cnt * running_cnt * (cnt + running_cnt)))
220220
running_sum += sum_
221221
running_cnt += cnt
222-
return running_sum / running_cnt, np.sqrt(running_sum_dev_prods / running_cnt)
222+
return running_sum / running_cnt, np.sqrt(running_sum_dev_prods /
223+
running_cnt)
223224

224225

225226
def filter_noisy_traces(workers, mean_trace, std_trace, max_std):
@@ -237,7 +238,8 @@ def filter_noisy_traces(workers, mean_trace, std_trace, max_std):
237238
min_trace = mean_trace - max_std * std_trace
238239
max_trace = mean_trace + max_std * std_trace
239240
tasks = [
240-
worker.filter_noisy_traces.remote(min_trace, max_trace) for worker in workers
241+
worker.filter_noisy_traces.remote(min_trace, max_trace)
242+
for worker in workers
241243
]
242244

243245
running_cnt = 0
@@ -392,7 +394,8 @@ def find_best_diffs(pairwise_diffs_scores):
392394
# the most likely differences between key bytes.
393395
G.add_edge(a, b, weight=DiffScore(pairwise_diffs_scores[a, b, 1]))
394396
# Find paths from key byte 0 to all other bytes.
395-
paths = nx.algorithms.shortest_paths.weighted.single_source_dijkstra_path(G, 0)
397+
paths = nx.algorithms.shortest_paths.weighted.single_source_dijkstra_path(
398+
G, 0)
396399
# Recover the paths and corresponding differences from key byte 0 to all
397400
# other bytes.
398401
diffs = np.zeros(16, dtype=np.uint8)
@@ -422,9 +425,11 @@ def recover_key(diffs, attack_direction, plaintext, ciphertext):
422425
# Create a matrix of all possible keys.
423426
keys = np.zeros((256, 16), np.uint8)
424427
for first_byte_val in range(256):
425-
key = np.asarray([diffs[i] ^ first_byte_val for i in range(16)], np.uint8)
428+
key = np.asarray([diffs[i] ^ first_byte_val for i in range(16)],
429+
np.uint8)
426430
if attack_direction == AttackDirection.OUTPUT:
427-
key = np.asarray(cwa.aes_funcs.key_schedule_rounds(key, 10, 0), np.uint8)
431+
key = np.asarray(cwa.aes_funcs.key_schedule_rounds(key, 10, 0),
432+
np.uint8)
428433
keys[first_byte_val] = key
429434
# Encrypt the plaintext using all candidates in parallel.
430435
ciphertexts = scared.aes.base.encrypt(plaintext, keys)
@@ -464,9 +469,8 @@ def compare_diffs(pairwise_diffs_scores, attack_direction, correct_key):
464469

465470

466471
@timer()
467-
def perform_attack(
468-
project_file, num_traces, attack_window, attack_direction, max_std, num_workers
469-
):
472+
def perform_attack(project_file, num_traces, attack_window, attack_direction,
473+
max_std, num_workers):
470474
"""Performs a correlation-enhanced power analysis collision attack.
471475
472476
This function:
@@ -506,9 +510,10 @@ def perform_attack(
506510
project_type = "ot_trace_library"
507511

508512
# Open the project.
509-
project_cfg = ProjectConfig(
510-
type=project_type, path=project_file, wave_dtype=np.uint16, overwrite=False
511-
)
513+
project_cfg = ProjectConfig(type=project_type,
514+
path=project_file,
515+
wave_dtype=np.uint16,
516+
overwrite=False)
512517
project = SCAProject(project_cfg)
513518
project.open_project()
514519

@@ -524,11 +529,11 @@ def perform_attack(
524529
f"Invalid attack window: {attack_window} (must be in [0, {last_sample}])"
525530
)
526531
if max_std <= 0:
527-
raise ValueError(f"Invalid max_std: {max_std} (must be greater than zero)")
532+
raise ValueError(
533+
f"Invalid max_std: {max_std} (must be greater than zero)")
528534
if num_workers <= 0:
529535
raise ValueError(
530-
f"Invalid num_workers: {num_workers} (must be greater than zero)"
531-
)
536+
f"Invalid num_workers: {num_workers} (must be greater than zero)")
532537

533538
# Instantiate workers
534539
def worker_trace_slices():
@@ -539,15 +544,15 @@ def worker_trace_slices():
539544
traces_per_worker = int(num_traces / num_workers)
540545
first_worker_num_traces = traces_per_worker + num_traces % num_workers
541546
yield slice(0, first_worker_num_traces)
542-
for trace_begin in range(
543-
first_worker_num_traces, num_traces, traces_per_worker
544-
):
547+
for trace_begin in range(first_worker_num_traces, num_traces,
548+
traces_per_worker):
545549
yield slice(trace_begin, trace_begin + traces_per_worker)
546550

547551
# Attack window is inclusive.
548552
attack_window = slice(attack_window[0], attack_window[1] + 1)
549553
workers = [
550-
TraceWorker.remote(project_file, trace_slice, attack_window, attack_direction)
554+
TraceWorker.remote(project_file, trace_slice, attack_window,
555+
attack_direction)
551556
for trace_slice in worker_trace_slices()
552557
]
553558
assert len(workers) == num_workers
@@ -556,32 +561,27 @@ def worker_trace_slices():
556561
# Filter noisy traces.
557562
orig_num_traces = num_traces
558563
num_traces = filter_noisy_traces(workers, mean, std_dev, max_std)
559-
logging.info(
560-
f"Will use {num_traces} traces "
561-
f"({100 * num_traces / orig_num_traces:.1f}% of all traces)"
562-
)
564+
logging.info(f"Will use {num_traces} traces "
565+
f"({100 * num_traces / orig_num_traces:.1f}% of all traces)")
563566
# Mean traces for all values of all text bytes.
564567
mean_text_traces = compute_mean_text_traces(workers)
565568
# Guess the differences between key bytes.
566569
pairwise_diffs_scores = compute_pairwise_diffs_and_scores(mean_text_traces)
567570
diffs = find_best_diffs(pairwise_diffs_scores)
568571
logging.info(f"Difference values (delta_0_i): {diffs}")
569572
# Recover the key.
570-
key = recover_key(
571-
diffs, attack_direction, project.get_plaintexts(0), project.get_ciphertexts(0)
572-
)
573+
key = recover_key(diffs, attack_direction, project.get_plaintexts(0),
574+
project.get_ciphertexts(0))
573575
if key is not None:
574576
logging.info(f"Recovered AES key: {bytes(key).hex()}")
575577
else:
576578
logging.error("Failed to recover the AES key")
577579
# Compare differences - both matrices are symmetric and have an all-zero main diagonal.
578-
correct_diffs = compare_diffs(
579-
pairwise_diffs_scores, attack_direction, project.get_keys(0)
580-
)
580+
correct_diffs = compare_diffs(pairwise_diffs_scores, attack_direction,
581+
project.get_keys(0))
581582
logging.info(
582583
f"Recovered {((np.sum(correct_diffs) - 16) / 2).astype(int)}/120 "
583-
"differences between key bytes"
584-
)
584+
"differences between key bytes")
585585
project.close(save=False)
586586
return key
587587

@@ -591,8 +591,7 @@ def parse_args():
591591
parser = argparse.ArgumentParser(
592592
description="""A distributed implementation of the attack described in
593593
"Correlation-Enhanced Power Analysis Collision Attack" by A. Moradi, O.
594-
Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf)."""
595-
)
594+
Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf).""")
596595
parser.add_argument(
597596
"-f",
598597
"--project-file",
@@ -649,8 +648,7 @@ def config_logger():
649648
sh = logging.StreamHandler()
650649
sh.setLevel(logging.INFO)
651650
formatter = logging.Formatter(
652-
"%(asctime)s %(levelname)s %(filename)s:%(lineno)d -- %(message)s"
653-
)
651+
"%(asctime)s %(levelname)s %(filename)s:%(lineno)d -- %(message)s")
654652
sh.setFormatter(formatter)
655653
logger.addHandler(sh)
656654
return logger
@@ -665,9 +663,9 @@ def main():
665663
ray.init(
666664
runtime_env={
667665
"working_dir": "../",
668-
"excludes": ["*.db", "*.cwp", "*.npy", "*.bit", "*/lfs/*", "*.pack"],
669-
}
670-
)
666+
"excludes":
667+
["*.db", "*.cwp", "*.npy", "*.bit", "*/lfs/*", "*.pack"],
668+
})
671669

672670
key = perform_attack(**vars(args))
673671
sys.exit(0 if key is not None else 1)

0 commit comments

Comments
 (0)