Skip to content

Commit 1ca12f1

Browse files
Basic SNN functionality in hls4ml (mapped from snntorch) (#1470)
* Add SNN work and ignore local notebook/test artifacts * jupyter notebook demo of snn functionality * updates to readme * Update README.md * add learnable beta and threhold functionality * update docs * readout and reset updated * update branch for pr * docs update * added comments * notebook update * pr bug fix * [pre-commit.ci] auto fixes from pre-commit hooks * remove notebook * refactor readout * update to SNNReadout attributes * Added backend support to docs * updated vivado->vitis in tests * update docs on RF support for spiking neurons * snn streaming moved to nnet_snn_stream.h * updated snn readout layer defaults * moved gettattr * snn window parsing moved to optimizer pass * reverted changes * snnreadout moved from utils to contrib * updated mentions of tutorial notebook * [pre-commit.ci] auto fixes from pre-commit hooks * add __contains__ to handle scalar spiking neuron attributes (threshold & beta) * test updates for snn * updated snnreadout defaults * update config conversion * layer attribute updates for keras and onnx --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 10dbeae commit 1ca12f1

18 files changed

Lines changed: 1622 additions & 3 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ docs/_build
1313
docs/autodoc/*
1414
hls4mlprj_*
1515
*~
16+
*.ipynb
1617
*.ipynb_checkpoints/
1718
*.bak

docs/advanced/snn.rst

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
========================================
2+
Spiking Neural Networks (PyTorch/SNN)
3+
========================================
4+
5+
This page describes the initial SNN support in the PyTorch frontend.
6+
7+
Install the SNN frontend dependencies with:
8+
9+
.. code-block:: bash
10+
11+
pip install hls4ml[snn]
12+
13+
Backend support
14+
===============
15+
16+
The SNN flow currently supports only the ``Vitis`` backend.
17+
18+
Execution model
19+
===============
20+
21+
Current hls4ml SNN implementations are synchronous (clock-driven). Neuron state
22+
updates and layer computations run in standard HLS pipelines/streams each cycle
23+
according to interface handshakes. The generated design is not a native
24+
asynchronous/event-routed neuromorphic architecture (yet!).
25+
26+
Reuse factor support
27+
====================
28+
29+
Standard hls4ml layers used inside an SNN, such as ``Dense``/linear layers,
30+
retain their normal ``ReuseFactor`` support. ``ReuseFactor`` can still be set at
31+
the model, layer type, or layer name level for these layers, and each dense layer
32+
uses its own configured value independently of the surrounding spiking neuron
33+
layers. The spiking neuron kernels themselves, ``IFNeuron`` and ``LIFNeuron``, do not
34+
currently expose ``ReuseFactor``. They process one timestep at a time, keep
35+
internal membrane state across timesteps, and unroll the per-neuron update loop
36+
across ``n_out`` channels.
37+
38+
Supported PyTorch modules and readout wrappers
39+
==============================================
40+
41+
The frontend currently supports direct parsing of:
42+
43+
* ``Leaky`` -> ``LIFNeuron`` (or ``IFNeuron`` when ``beta`` is effectively 1)
44+
45+
``SNNReadout`` is an hls4ml layer, not a ``snntorch`` module. To use the
46+
built-in hls4ml readout from a PyTorch model, instantiate the provided PyTorch
47+
marker module:
48+
49+
.. code-block:: python
50+
51+
from hls4ml.contrib.snntorch import SNNReadout
52+
53+
The marker is an identity in PyTorch and is converted to the hls4ml
54+
``SNNReadout`` layer by the PyTorch frontend.
55+
56+
`snntorch` tracing
57+
==================
58+
59+
``snntorch`` modules are treated as leaf modules by the hls4ml PyTorch FX tracer.
60+
This allows conversion models to use ``snntorch.Leaky`` directly without defining
61+
conversion-only wrapper classes.
62+
63+
For ``Leaky``, the supported reset mechanisms are:
64+
65+
* ``subtract``
66+
* ``zero``
67+
68+
``threshold`` supports scalar or per-neuron vectors (length ``n_out``) for both ``IFNeuron`` and ``LIFNeuron``.
69+
``beta`` supports scalar or per-neuron vectors for ``LIFNeuron``.
70+
71+
Conversion selects the most memory-efficient representation automatically:
72+
73+
* scalar values are emitted as compile-time constants
74+
* per-neuron values are emitted as parameter vectors
75+
76+
For trainable snntorch parameters, conversion uses the current parameter values from the model
77+
at conversion time.
78+
79+
Readout and Decision Rules
80+
==========================
81+
82+
The hls4ml ``SNNReadout`` layer implements programmable per-model decision policies.
83+
By default, ``output_mode="spike"`` preserves the original spike-count behavior:
84+
85+
* ``argmax_spike_count``
86+
* ``first_to_threshold``
87+
* ``threshold_then_argmax``
88+
* ``binary_logit`` (for binary classifiers with ``n_classes == 2``)
89+
90+
The layer accumulates class spikes over a window. For most decision rules it emits
91+
a class ID. For ``binary_logit``, it emits a score equal to
92+
``count(class_1) - count(class_0)``.
93+
94+
For non-spiking readout heads, set ``output_mode="membrane"`` and connect
95+
``SNNReadout`` directly after the final dense/linear layer instead of after a
96+
final spiking neuron. In this mode the readout owns the final membrane state:
97+
98+
.. code-block:: python
99+
100+
x = self.fc2(x)
101+
return self.readout(x)
102+
103+
At each timestep, the generated readout computes:
104+
105+
.. code-block:: cpp
106+
107+
mem[i] = beta * mem[i] + input[i];
108+
109+
No threshold or reset-on-spike is applied in membrane mode. The supported
110+
membrane decision policies are:
111+
112+
* ``argmax_membrane``
113+
* ``binary_logit`` (emits ``mem(class_1) - mem(class_0)`` for binary classifiers)
114+
115+
This will be explained in a tutorial in the hls4ml-tutorials repo.
116+
117+
Do not place a final spiking neuron before ``SNNReadout(output_mode="membrane")``
118+
unless you intentionally want the readout to consume that neuron's spike output.
119+
The membrane mode does not recover or expose the internal membrane state of a
120+
preceding ``Leaky``/``IFNeuron``/``LIFNeuron`` layer. If a final output neuron
121+
has a learnable ``beta``, that learnable neuron membrane is not the same state
122+
as the readout-owned membrane. The readout uses its own scalar ``beta``.
123+
124+
When using the default PyTorch parser, the wrapper module should expose these
125+
attributes as needed:
126+
127+
* ``n_classes`` (defaults to the input feature count if omitted)
128+
* ``window_size`` or ``stream_length`` (defaults to ``1``)
129+
* ``class_threshold`` (defaults to ``1``)
130+
* ``output_mode`` (defaults to ``spike``; use ``membrane`` for readout-owned membrane accumulation)
131+
* ``beta`` (defaults to ``1.0`` for membrane readout)
132+
* ``decision_rule`` (defaults to ``argmax_spike_count``)
133+
* ``reset_policy`` or ``state_reset_policy`` (defaults to ``fixed_window``)
134+
135+
Window Boundary Semantics
136+
=========================
137+
138+
The current implementation uses ``window_size`` timesteps as the sequence boundary
139+
for generated HLS. During PyTorch conversion, the first fixed-window
140+
``SNNReadout``'s ``window_size`` is propagated to all converted ``IFNeuron`` and
141+
``LIFNeuron`` layers in the graph.
142+
143+
At each boundary:
144+
145+
* the class decision is emitted
146+
* internal readout counters or readout membrane state are reset for the next sequence
147+
* internal ``IFNeuron``/``LIFNeuron`` membrane state is reset for the next sequence
148+
149+
The reset happens after the final timestep has been processed and has contributed
150+
to the output. This behavior is compatible with fixed-length time windows.
151+
152+
Only fixed-window reset is implemented in generated layer kernels today.
153+
``state_reset_policy`` accepts future-facing values such as ``tlast``,
154+
``host_pulse``, and ``never``, but the current layer kernels still use fixed
155+
``window_size`` reset behavior.
156+
157+
Running ``hls_model.predict()``
158+
==============================
159+
160+
Compiled SNN models are stateful across top-function calls. For fixed-window
161+
SNN inference, call the compiled model once per timestep and pass exactly
162+
``window_size`` timesteps for each independent sequence:
163+
164+
.. code-block:: python
165+
166+
last = None
167+
for step in range(timesteps):
168+
x_step = x_sequence[step].astype("float32")[None, :]
169+
last = hls_model.predict(x_step)
170+
171+
After the last call in the window, generated HLS resets the neuron and readout
172+
state for the next sequence. Avoid making stray single-timestep ``predict``
173+
calls before evaluating a sequence, because those calls advance the state.
174+
175+
For membrane readout, the PyTorch reference should match the generated readout
176+
accumulation:
177+
178+
.. code-block:: python
179+
180+
mem = torch.zeros_like(currents[:, 0, :])
181+
for step in range(currents.shape[1]):
182+
mem = beta * mem + currents[:, step, :]
183+
pred = mem.argmax(dim=1)
184+
185+
Using only the final dense current, or using spike-count reduction for a
186+
membrane readout, does not match generated HLS behavior.
187+
188+
Precision note
189+
==============
190+
191+
Membrane readout accumulates dense currents over the full window, so very narrow
192+
fixed-point types can reduce accuracy even when the floating-point PyTorch model
193+
looks good.
194+
195+
``TLAST`` note
196+
==============
197+
198+
True AXI sideband ``TLAST`` boundary handling requires top-level writer/interface support for packetized AXI stream types.
199+
The current implementation does not yet expose ``TLAST`` to layer kernels directly.
200+
201+
For variable-length windows, a practical workaround is to keep the hls4ml core unchanged and perform ``TLAST`` to boundary conversion in a thin wrapper IP around the generated project.

docs/frontend/pytorch.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ of the model. If the ``io_parallel`` I/O type (see :ref:`Concepts`) is used, a t
1818
Outputs are not transposed back by default, but in ``io_parallel`` case, a transpose node can be added. If not needed, these adjustments can also be switched off. See :py:class:`~hls4ml.utils.config.config_from_pytorch_model` for details.
1919

2020
The equivalent of Keras extension API is not yet available for PyTorch parser, and will be provided in the future.
21+
22+
.. note::
23+
Experimental spiking layer support is available for selected modules. See :doc:`../advanced/snn` for details.

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
advanced/precision
5252
advanced/fifo_depth
5353
advanced/extension
54+
advanced/snn
5455
advanced/model_optimization
5556
advanced/bramfactor
5657
advanced/plugins
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
2+
from hls4ml.model.layers import IFNeuron, LIFNeuron, SNNReadout
3+
4+
if_config_template = """struct config{index} : nnet::if_neuron_config {{
5+
static const unsigned n_in = {n_in};
6+
static const unsigned n_out = {n_out};
7+
static const unsigned io_type = nnet::{iotype};
8+
static const unsigned window_size = {window_size};
9+
static const bool threshold_is_vector = {threshold_is_vector};
10+
static constexpr float threshold = {threshold};
11+
static const nnet::snn_reset_mode reset_mode = nnet::snn_reset_mode::{reset_mechanism};
12+
typedef {threshold_t.name} threshold_t;
13+
typedef {membrane_t.name} membrane_t;
14+
}};\n"""
15+
16+
if_function_template = 'nnet::if_neuron<{input_t}, {output_t}, {config}>({input}, {output}, {threshold});'
17+
snn_include_list = ['nnet_utils/nnet_snn.h', 'nnet_utils/nnet_snn_stream.h']
18+
19+
20+
class IFNeuronConfigTemplate(LayerConfigTemplate):
21+
def __init__(self):
22+
super().__init__(IFNeuron)
23+
self.template = if_config_template
24+
25+
def format(self, node):
26+
params = self._default_config_params(node)
27+
params['threshold_is_vector'] = 'true' if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'false'
28+
return self.template.format(**params)
29+
30+
31+
class IFNeuronFunctionTemplate(FunctionCallTemplate):
32+
def __init__(self):
33+
super().__init__(IFNeuron, include_header=snn_include_list)
34+
self.template = if_function_template
35+
36+
def format(self, node):
37+
params = self._default_function_params(node)
38+
params['threshold'] = (
39+
node.get_weights('threshold_vec').name if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'nullptr'
40+
)
41+
return self.template.format(**params)
42+
43+
44+
lif_config_template = """struct config{index} : nnet::lif_neuron_config {{
45+
static const unsigned n_in = {n_in};
46+
static const unsigned n_out = {n_out};
47+
static const unsigned io_type = nnet::{iotype};
48+
static const unsigned window_size = {window_size};
49+
static const bool beta_is_vector = {beta_is_vector};
50+
static const bool threshold_is_vector = {threshold_is_vector};
51+
static constexpr float threshold = {threshold};
52+
static constexpr float beta = {beta};
53+
static const nnet::snn_reset_mode reset_mode = nnet::snn_reset_mode::{reset_mechanism};
54+
typedef {beta_t.name} beta_t;
55+
typedef {threshold_t.name} threshold_t;
56+
typedef {membrane_t.name} membrane_t;
57+
}};\n"""
58+
59+
lif_function_template = 'nnet::lif_neuron<{input_t}, {output_t}, {config}>({input}, {output}, {beta}, {threshold});'
60+
61+
62+
class LIFNeuronConfigTemplate(LayerConfigTemplate):
63+
def __init__(self):
64+
super().__init__(LIFNeuron)
65+
self.template = lif_config_template
66+
67+
def format(self, node):
68+
params = self._default_config_params(node)
69+
params['beta_is_vector'] = 'true' if node.get_attr('beta_mode', 'scalar') == 'vector' else 'false'
70+
params['threshold_is_vector'] = 'true' if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'false'
71+
return self.template.format(**params)
72+
73+
74+
class LIFNeuronFunctionTemplate(FunctionCallTemplate):
75+
def __init__(self):
76+
super().__init__(LIFNeuron, include_header=snn_include_list)
77+
self.template = lif_function_template
78+
79+
def format(self, node):
80+
params = self._default_function_params(node)
81+
params['beta'] = node.get_weights('beta_vec').name if node.get_attr('beta_mode', 'scalar') == 'vector' else 'nullptr'
82+
params['threshold'] = (
83+
node.get_weights('threshold_vec').name if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'nullptr'
84+
)
85+
return self.template.format(**params)
86+
87+
88+
readout_config_template = """struct config{index} : nnet::snn_readout_config {{
89+
static const unsigned n_classes = {n_classes};
90+
static const unsigned io_type = nnet::{iotype};
91+
static const unsigned window_size = {window_size};
92+
static const unsigned class_threshold = {class_threshold};
93+
static constexpr float beta = {beta};
94+
static const nnet::snn_readout_mode output_mode = nnet::snn_readout_mode::{output_mode};
95+
static const nnet::snn_decision_rule decision_rule = nnet::snn_decision_rule::{decision_rule};
96+
typedef {membrane_t.name} membrane_t;
97+
}};\n"""
98+
99+
readout_function_template = 'nnet::snn_readout<{input_t}, {output_t}, {config}>({input}, {output});'
100+
101+
102+
class SNNReadoutConfigTemplate(LayerConfigTemplate):
103+
def __init__(self):
104+
super().__init__(SNNReadout)
105+
self.template = readout_config_template
106+
107+
def format(self, node):
108+
params = self._default_config_params(node)
109+
return self.template.format(**params)
110+
111+
112+
class SNNReadoutFunctionTemplate(FunctionCallTemplate):
113+
def __init__(self):
114+
super().__init__(SNNReadout, include_header=snn_include_list)
115+
self.template = readout_function_template
116+
117+
def format(self, node):
118+
params = self._default_function_params(node)
119+
return self.template.format(**params)
120+
121+
122+
def register_snn_templates(backend):
123+
backend.register_template(IFNeuronConfigTemplate)
124+
backend.register_template(IFNeuronFunctionTemplate)
125+
backend.register_template(LIFNeuronConfigTemplate)
126+
backend.register_template(LIFNeuronFunctionTemplate)
127+
backend.register_template(SNNReadoutConfigTemplate)
128+
backend.register_template(SNNReadoutFunctionTemplate)

0 commit comments

Comments
 (0)