Commit 5aceb2e

Docs
1 parent e9dfcbf commit 5aceb2e

3 files changed: +100 −0 lines changed

docs/source/guide/guide_part_i.rst

Lines changed: 50 additions & 0 deletions
@@ -375,3 +375,53 @@ their simulation separately at the temporal granularity of chosen :code:`dt`, in

This is a strict departure from the computation of *deep neural networks* (DNNs), in which an ordering of layers is
supposed, and layers' activations are computed *in sequence* from the shallowest to the deepest layer in a single time
step, with the exclusion of recurrent layers, whose computations are still ordered in time.

Lowering precision
------------------

You can choose the precision of the weights by passing the :code:`value_dtype` parameter to the :code:`Weight` class:

.. code-block:: python

    MulticompartmentConnection(
        ...
        pipeline=[
            Weight(
                'weight',
                w,
                value_dtype='float16',
                ...
            )
        ]
    )

Below are performance statistics for float16 and float32, obtained by running
examples/benchmark/lowering_precision.py.

.. code-block:: text

    precision: float32
    Time (sec) | GPU memory (Mb)
    19.7812    | 52
    19.4812    | 52
    19.0769    | 52
    19.1530    | 52
    Average time: 19.373075
    Average memory: 52.0

    precision: float16
    Time (sec) | GPU memory (Mb)
    19.5023    | 49
    20.5734    | 49
    19.8735    | 49
    19.8931    | 49
    Average time: 19.960575
    Average memory: 49.0

As the figures show, lowering precision from float32 to float16 does not provide a significant advantage in time or
memory: the float16 option only reduces GPU memory usage by about 6% (52 Mb to 49 Mb), and is marginally slower on
average.
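Since float16 halves the per-element storage of a weight tensor, a larger saving might be expected; the modest ~6% reduction suggests that weights account for only a small share of the peak GPU allocation. A minimal sketch of the per-element arithmetic, using NumPy purely for illustration (the library stores weights as torch tensors, and the 100 × 400 shape below is hypothetical, not taken from the benchmark):

```python
import numpy as np

# float16 stores each value in 2 bytes, float32 in 4: a 2x saving per element.
print(np.dtype("float16").itemsize, np.dtype("float32").itemsize)  # 2 4

# A hypothetical 100 x 400 weight matrix in each precision:
w32 = np.zeros((100, 400), dtype=np.float32)
w16 = w32.astype(np.float16)
print(w32.nbytes, w16.nbytes)  # 160000 80000

# The benchmark's observed saving is far smaller than 2x, consistent with
# weights being a small fraction of the ~50 Mb peak allocation.
print(1 - 49 / 52)  # ~0.058, the ~6% reported above
```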
examples/benchmark/lowering_precision.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@

import os
import re
import subprocess
from statistics import mean

precision_sample_size = 4
precisions = ['float16', 'float32']

folder = os.path.dirname(os.path.dirname(__file__))
script = os.path.join(folder, 'mnist', 'batch_eth_mnist.py')

# Run the MNIST benchmark several times per precision and collect the
# (time, memory) pairs parsed from its stdout.
data = {}
for precision in precisions:
    for _ in range(precision_sample_size):
        result = subprocess.run(
            f"python {script} --n_train 100 --batch_size 50 --n_test 10 --n_updates 1 --w_dtype {precision}",
            shell=True, capture_output=True, text=True,
        )
        output = result.stdout
        time_match = re.search(r'Progress: 1 / 1 \((\d+\.\d+) seconds\)', output)
        memory_match = re.search(r'Memory consumption: (\d+)mb', output)
        data.setdefault(precision, []).append([
            time_match.group(1),
            memory_match.group(1),
        ])
        print("+")  # progress marker: one '+' per completed run


def print_table(data):
    # Pad each column to the width of its widest entry.
    column_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
    for row in data:
        formatted_row = " | ".join(f"{str(item):<{column_widths[i]}}" for i, item in enumerate(row))
        print(formatted_row)


average_time = {}
average_memory = {}
for precision, rows in data.items():
    print(f"precision: {precision}")
    table = [
        ['Time (sec)', 'GPU memory (Mb)']
    ] + rows
    avg_time = mean(float(row[0]) for row in rows)
    avg_memory = mean(float(row[1]) for row in rows)
    print_table(table)
    print(f"Average time: {avg_time}")
    print(f"Average memory: {avg_memory}")
    average_memory[precision] = avg_memory
    average_time[precision] = avg_time
    print('')
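The averages the script prints come straight from `statistics.mean` over the parsed samples. As a quick sanity check, feeding it the timing samples listed in the docs reproduces the reported averages:

```python
from statistics import mean

# Timing samples for the two precisions, as reported by the benchmark run.
float32_times = [19.7812, 19.4812, 19.0769, 19.1530]
float16_times = [19.5023, 20.5734, 19.8735, 19.8931]

print(mean(float32_times))  # ~19.373075
print(mean(float16_times))  # ~19.960575

# float16 was actually slightly slower on average in this run.
print(mean(float16_times) > mean(float32_times))  # True
```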

examples/mnist/batch_eth_mnist.py

Lines changed: 1 addition & 0 deletions
@@ -384,6 +384,7 @@

     print("\nAll activity accuracy: %.2f" % (accuracy["all"] / n_test))
     print("Proportion weighting accuracy: %.2f \n" % (accuracy["proportion"] / n_test))
+    print(f"Memory consumption: {round(torch.cuda.max_memory_allocated(device=None) / 1024 ** 2)}mb")

     print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
     print("\nTesting complete.\n")
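The added line converts the peak byte count reported by `torch.cuda.max_memory_allocated` into whole megabytes by dividing by `1024 ** 2` and rounding. The conversion is plain arithmetic, sketched here with illustrative byte counts (the helper name is hypothetical):

```python
# Convert a peak allocation in bytes to whole mebibytes, mirroring the
# added print statement (torch.cuda.max_memory_allocated returns bytes).
def bytes_to_mb(n_bytes: int) -> int:
    return round(n_bytes / 1024 ** 2)

print(bytes_to_mb(52 * 1024 ** 2))  # 52
print(bytes_to_mb(49 * 1024 ** 2 + 300_000))  # 49 -- sub-Mb remainders round away
```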
