|
| 1 | +======================================== |
| 2 | +Spiking Neural Networks (PyTorch/SNN) |
| 3 | +======================================== |
| 4 | + |
| 5 | +This page describes the initial SNN support in the PyTorch frontend. |
| 6 | + |
| 7 | +Install the SNN frontend dependencies with: |
| 8 | + |
| 9 | +.. code-block:: bash |
| 10 | +
|
| 11 | + pip install hls4ml[snn] |
| 12 | +
|
| 13 | +Backend support |
| 14 | +=============== |
| 15 | + |
| 16 | +The SNN flow currently supports only the ``Vitis`` backend. |
| 17 | + |
| 18 | +Execution model |
| 19 | +=============== |
| 20 | + |
| 21 | +Current hls4ml SNN implementations are synchronous (clock-driven). Neuron state |
| 22 | +updates and layer computations run in standard HLS pipelines/streams each cycle |
| 23 | +according to interface handshakes. The generated design is not a native |
| 24 | +asynchronous/event-routed neuromorphic architecture (yet!). |
| 25 | + |
| 26 | +Reuse factor support |
| 27 | +==================== |
| 28 | + |
| 29 | +Standard hls4ml layers used inside an SNN, such as ``Dense``/linear layers, |
| 30 | +retain their normal ``ReuseFactor`` support. ``ReuseFactor`` can still be set at |
| 31 | +the model, layer type, or layer name level for these layers, and each dense layer |
| 32 | +uses its own configured value independently of the surrounding spiking neuron |
| 33 | +layers. The spiking neuron kernels themselves, ``IFNeuron`` and ``LIFNeuron``, do not |
| 34 | +currently expose ``ReuseFactor``. They process one timestep at a time, keep |
| 35 | +internal membrane state across timesteps, and unroll the per-neuron update loop |
| 36 | +across ``n_out`` channels. |
| 37 | + |
| 38 | +Supported PyTorch modules and readout wrappers |
| 39 | +============================================== |
| 40 | + |
| 41 | +The frontend currently supports direct parsing of: |
| 42 | + |
| 43 | +* ``Leaky`` -> ``LIFNeuron`` (or ``IFNeuron`` when ``beta`` is effectively 1) |
| 44 | + |
| 45 | +``SNNReadout`` is an hls4ml layer, not a ``snntorch`` module. To use the |
| 46 | +built-in hls4ml readout from a PyTorch model, instantiate the provided PyTorch |
| 47 | +marker module: |
| 48 | + |
| 49 | +.. code-block:: python |
| 50 | +
|
| 51 | + from hls4ml.contrib.snntorch import SNNReadout |
| 52 | +
|
| 53 | +The marker is an identity in PyTorch and is converted to the hls4ml |
| 54 | +``SNNReadout`` layer by the PyTorch frontend. |
| 55 | + |
| 56 | +`snntorch` tracing |
| 57 | +================== |
| 58 | + |
| 59 | +``snntorch`` modules are treated as leaf modules by the hls4ml PyTorch FX tracer. |
| 60 | +This allows conversion models to use ``snntorch.Leaky`` directly without defining |
| 61 | +conversion-only wrapper classes. |
| 62 | + |
| 63 | +For ``Leaky``, the supported reset mechanisms are: |
| 64 | + |
| 65 | +* ``subtract`` |
| 66 | +* ``zero`` |
| 67 | + |
| 68 | +``threshold`` supports scalar or per-neuron vectors (length ``n_out``) for both ``IFNeuron`` and ``LIFNeuron``. |
| 69 | +``beta`` supports scalar or per-neuron vectors for ``LIFNeuron``. |
| 70 | + |
| 71 | +Conversion selects the most memory-efficient representation automatically: |
| 72 | + |
| 73 | +* scalar values are emitted as compile-time constants |
| 74 | +* per-neuron values are emitted as parameter vectors |
| 75 | + |
| 76 | +For trainable snntorch parameters, conversion uses the current parameter values from the model |
| 77 | +at conversion time. |
| 78 | + |
| 79 | +Readout and Decision Rules |
| 80 | +========================== |
| 81 | + |
| 82 | +The hls4ml ``SNNReadout`` layer implements programmable per-model decision policies. |
| 83 | +By default, ``output_mode="spike"`` preserves the original spike-count behavior: |
| 84 | + |
| 85 | +* ``argmax_spike_count`` |
| 86 | +* ``first_to_threshold`` |
| 87 | +* ``threshold_then_argmax`` |
| 88 | +* ``binary_logit`` (for binary classifiers with ``n_classes == 2``) |
| 89 | + |
| 90 | +The layer accumulates class spikes over a window. For most decision rules it emits |
| 91 | +a class ID. For ``binary_logit``, it emits a score equal to |
| 92 | +``count(class_1) - count(class_0)``. |
| 93 | + |
| 94 | +For non-spiking readout heads, set ``output_mode="membrane"`` and connect |
| 95 | +``SNNReadout`` directly after the final dense/linear layer instead of after a |
| 96 | +final spiking neuron. In this mode the readout owns the final membrane state: |
| 97 | + |
| 98 | +.. code-block:: python |
| 99 | +
|
| 100 | + x = self.fc2(x) |
| 101 | + return self.readout(x) |
| 102 | +
|
| 103 | +At each timestep, the generated readout computes: |
| 104 | + |
| 105 | +.. code-block:: cpp |
| 106 | +
|
| 107 | + mem[i] = beta * mem[i] + input[i]; |
| 108 | +
|
| 109 | +No threshold or reset-on-spike is applied in membrane mode. The supported |
| 110 | +membrane decision policies are: |
| 111 | + |
| 112 | +* ``argmax_membrane`` |
| 113 | +* ``binary_logit`` (emits ``mem(class_1) - mem(class_0)`` for binary classifiers) |
| 114 | + |
| 115 | +This will be explained in a tutorial in the hls4ml-tutorials repo. |
| 116 | + |
| 117 | +Do not place a final spiking neuron before ``SNNReadout(output_mode="membrane")`` |
| 118 | +unless you intentionally want the readout to consume that neuron's spike output. |
| 119 | +The membrane mode does not recover or expose the internal membrane state of a |
| 120 | +preceding ``Leaky``/``IFNeuron``/``LIFNeuron`` layer. If a final output neuron |
| 121 | +has a learnable ``beta``, that learnable neuron membrane is not the same state |
| 122 | +as the readout-owned membrane. The readout uses its own scalar ``beta``. |
| 123 | + |
| 124 | +When using the default PyTorch parser, the wrapper module should expose these |
| 125 | +attributes as needed: |
| 126 | + |
| 127 | +* ``n_classes`` (defaults to the input feature count if omitted) |
| 128 | +* ``window_size`` or ``stream_length`` (defaults to ``1``) |
| 129 | +* ``class_threshold`` (defaults to ``1``) |
| 130 | +* ``output_mode`` (defaults to ``spike``; use ``membrane`` for readout-owned membrane accumulation) |
| 131 | +* ``beta`` (defaults to ``1.0`` for membrane readout) |
| 132 | +* ``decision_rule`` (defaults to ``argmax_spike_count``) |
| 133 | +* ``reset_policy`` or ``state_reset_policy`` (defaults to ``fixed_window``) |
| 134 | + |
| 135 | +Window Boundary Semantics |
| 136 | +========================= |
| 137 | + |
| 138 | +The current implementation uses ``window_size`` timesteps as the sequence boundary |
| 139 | +for generated HLS. During PyTorch conversion, the first fixed-window |
| 140 | +``SNNReadout``'s ``window_size`` is propagated to all converted ``IFNeuron`` and |
| 141 | +``LIFNeuron`` layers in the graph. |
| 142 | + |
| 143 | +At each boundary: |
| 144 | + |
| 145 | +* the class decision is emitted |
| 146 | +* internal readout counters or readout membrane state are reset for the next sequence |
| 147 | +* internal ``IFNeuron``/``LIFNeuron`` membrane state is reset for the next sequence |
| 148 | + |
| 149 | +The reset happens after the final timestep has been processed and has contributed |
| 150 | +to the output. This behavior is compatible with fixed-length time windows. |
| 151 | + |
| 152 | +Only fixed-window reset is implemented in generated layer kernels today. |
| 153 | +``state_reset_policy`` accepts future-facing values such as ``tlast``, |
| 154 | +``host_pulse``, and ``never``, but the current layer kernels still use fixed |
| 155 | +``window_size`` reset behavior. |
| 156 | + |
| 157 | +Running ``hls_model.predict()`` |
| 158 | +============================== |
| 159 | + |
| 160 | +Compiled SNN models are stateful across top-function calls. For fixed-window |
| 161 | +SNN inference, call the compiled model once per timestep and pass exactly |
| 162 | +``window_size`` timesteps for each independent sequence: |
| 163 | + |
| 164 | +.. code-block:: python |
| 165 | +
|
| 166 | + last = None |
| 167 | + for step in range(timesteps): |
| 168 | + x_step = x_sequence[step].astype("float32")[None, :] |
| 169 | + last = hls_model.predict(x_step) |
| 170 | +
|
| 171 | +After the last call in the window, generated HLS resets the neuron and readout |
| 172 | +state for the next sequence. Avoid making stray single-timestep ``predict`` |
| 173 | +calls before evaluating a sequence, because those calls advance the state. |
| 174 | + |
| 175 | +For membrane readout, the PyTorch reference should match the generated readout |
| 176 | +accumulation: |
| 177 | + |
| 178 | +.. code-block:: python |
| 179 | +
|
| 180 | + mem = torch.zeros_like(currents[:, 0, :]) |
| 181 | + for step in range(currents.shape[1]): |
| 182 | + mem = beta * mem + currents[:, step, :] |
| 183 | + pred = mem.argmax(dim=1) |
| 184 | +
|
| 185 | +Using only the final dense current, or using spike-count reduction for a |
| 186 | +membrane readout, does not match generated HLS behavior. |
| 187 | + |
| 188 | +Precision note |
| 189 | +============== |
| 190 | + |
| 191 | +Membrane readout accumulates dense currents over the full window, so very narrow |
| 192 | +fixed-point types can reduce accuracy even when the floating-point PyTorch model |
| 193 | +looks good. |
| 194 | + |
| 195 | +``TLAST`` note |
| 196 | +============== |
| 197 | + |
| 198 | +True AXI sideband ``TLAST`` boundary handling requires top-level writer/interface support for packetized AXI stream types. |
| 199 | +The current implementation does not yet expose ``TLAST`` to layer kernels directly. |
| 200 | + |
| 201 | +For variable-length windows, a practical workaround is to keep the hls4ml core unchanged and perform ``TLAST`` to boundary conversion in a thin wrapper IP around the generated project. |
0 commit comments