Skip to content

Commit 1099d65

Browse files
authored
Da custom layer (#1429)
* use metaclass for handler reg * feat: general da4ml fallback * bump da4ml version for LUT-Layer support * fix syntax err * api and docstr
1 parent f12d2cd commit 1099d65

21 files changed

Lines changed: 201 additions & 80 deletions

docs/advanced/extension.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ In this case, there a single output with the same shape as the input.
8181
8282
from hls4ml.converters.keras_v3._base import register, KerasV3LayerHandler
8383
84-
@register
84+
# Handlers are registered by metaclass
8585
class KReverseHandler(KerasV3LayerHandler):
8686
'''Keras v3 layer handler for KReverse'''
8787

hls4ml/backends/vivado/passes/distributed_arithmetic.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77

8-
from hls4ml.model.layers import Conv1D, Conv2D, Dense, EinsumDense, Layer
8+
from hls4ml.model.layers import Conv1D, Conv2D, DACombinational, Dense, EinsumDense, Layer
99
from hls4ml.model.optimizer import OptimizerPass
1010
from hls4ml.model.optimizer.passes.bit_exact import get_input_layers, get_output_layers, im2col, pad_arrs, stride_arrs
1111
from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer
@@ -408,3 +408,52 @@ def transform(self, model: 'ModelGraph', node: Layer):
408408
del node.attributes['weight_data']
409409
del node.attributes['weight']
410410
del node.attributes['weight_t']
411+
412+
413+
class DACombinationalTemplate(OptimizerPass):
414+
def match(self, node):
415+
return isinstance(node, DACombinational)
416+
417+
def transform(self, model: 'ModelGraph', node: DACombinational):
418+
from da4ml.codegen.hls import hls_logic_and_bridge_gen
419+
from da4ml.trace import FixedVariableArrayInput, comb_trace
420+
421+
io_type = model.config.get_config_value('IOType')
422+
if io_type != 'io_parallel':
423+
original_type = node.attributes['original_type']
424+
raise ValueError(f'DACombinational layer (from {original_type}) only supports io_parallel.')
425+
426+
if not node.attributes.get('bit_exact_transformed', False):
427+
inp_p: FixedPrecisionType = node.get_input_variable().type.precision
428+
B, I, s = inp_p.width, inp_p.integer, inp_p.signed
429+
i, f = I - s, B - I
430+
comb = node.attributes['da_comb_logic']
431+
inp = FixedVariableArrayInput(comb.shape[0]).quantize(s, i, f)
432+
out = comb(inp)
433+
comb = comb_trace(inp, out)
434+
node.attributes['da_comb_logic'] = comb
435+
436+
comb = node.attributes['da_comb_logic']
437+
438+
backend = model.config.get_config_value('Backend').lower()
439+
if backend in ('vitis', 'vivado'):
440+
flavor = 'vitis'
441+
elif backend == 'oneapi':
442+
flavor = 'oneapi'
443+
else:
444+
raise ValueError(f'Unsupported backend {backend} for DACombinational layer.')
445+
446+
fn_name = f'da_comblogic_{node.index}'
447+
comb_logic, _ = hls_logic_and_bridge_gen(
448+
comb, fn_name, flavor=flavor, pragmas=['#pragma HLS INLINE'], print_latency=True
449+
)
450+
namespace = model.config.get_writer_config().get('Namespace', None) or 'nnet'
451+
452+
inp_t: str = node.get_input_variable().type.name
453+
out_t: str = node.get_output_variable().type.name
454+
inp_name: str = node.get_input_variable().name
455+
out_name: str = node.get_output_variable().name
456+
457+
fn_cpp = f'{namespace}::{fn_name}<{inp_t}, {out_t}>({inp_name}, {out_name});'
458+
node.attributes['da_codegen'] = Source(comb_logic)
459+
node.attributes['function_cpp'] = fn_cpp

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def _register_flows(self):
187187
'vivado:set_pipeline_style',
188188
'vivado:d_a_latency_dense_template',
189189
'vivado:d_a_latency_conv_template',
190+
'vivado:d_a_combinational_template',
190191
]
191192
vivado_types_flow = register_flow('specific_types', vivado_types, requires=[init_flow], backend=self.name)
192193

hls4ml/converters/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ def convert_from_keras_model(
167167
backend='Vivado',
168168
hls_config=None,
169169
bit_exact=None,
170+
allow_da_fallback=True,
171+
allow_v2_fallback=True,
170172
**kwargs,
171173
):
172174
"""Convert Keras model to hls4ml model based on the provided configuration.
@@ -195,12 +197,17 @@ def convert_from_keras_model(
195197
io_type (str, optional): Type of implementation used. One of
196198
'io_parallel' or 'io_stream'. Defaults to 'io_parallel'.
197199
hls_config (dict, optional): The HLS config.
198-
kwargs** (dict, optional): Additional parameters that will be used to create the config of the specified backend
199200
bit_exact (bool, optional): If True, enable model-wise precision propagation
200201
with **only fixed-point data types**. If None, enable if there is at least one
201202
FixedPointQuantizer layer in the model (only resulting from converting HGQ1/2
202203
models for now). By default, None.
204+
allow_da_fallback: Whether to allow fallback to DA combinational logic generation
205+
for unsupported layers. Only affects keras v3 models. Defaults to True.
206+
allow_v2_fallback: Whether to allow fallback to keras v2 layer handlers
207+
for unsupported layers. Only affects keras v3 models. Defaults to True. If both this and
208+
`allow_da_fallback` are True, DA fallback is attempted first.
203209
210+
kwargs** (dict, optional): Additional parameters that will be used to create the config of the specified backend
204211
Raises:
205212
Exception: If precision and reuse factor are not present in 'hls_config'.
206213
@@ -227,7 +234,7 @@ def convert_from_keras_model(
227234
import keras
228235

229236
if keras.__version__ >= '3.0':
230-
return keras_v3_to_hls(config)
237+
return keras_v3_to_hls(config, allow_da_fallback, allow_v2_fallback)
231238

232239
return keras_v2_to_hls(config)
233240

hls4ml/converters/keras_v3/_base.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,16 @@ def maybe_add_attrs(config: dict[str, Any] | DefaultConfig, obj: Any, *attrs: st
3333
config[attr] = getattr(obj, attr)
3434

3535

36-
class KerasV3LayerHandler:
36+
class KerasV3LayerHandlerMeta(type):
37+
def __new__(cls, name, bases, attrs):
38+
new_class = super().__new__(cls, name, bases, attrs)
39+
if 'handles' in attrs:
40+
for handle in attrs['handles']:
41+
registry[handle] = new_class()
42+
return new_class
43+
44+
45+
class KerasV3LayerHandler(metaclass=KerasV3LayerHandlerMeta):
3746
"""Base class for keras v3 layer handlers. Subclass this class to create a handler for a specific layer type."""
3847

3948
handles = ()
@@ -160,26 +169,3 @@ def load_weight(self, layer: 'keras.Layer', key: str):
160169
import keras
161170

162171
return keras.ops.convert_to_numpy(getattr(layer, key))
163-
164-
165-
def register(cls: type):
166-
"""Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class.
167-
168-
Args:
169-
cls: the class to register the handler for.
170-
171-
Examples:
172-
```python
173-
@keras_dispatcher.register
174-
class MyLayerHandler(KerasV3LayerHandler):
175-
handles = ('my_package.src.submodule.MyLayer', 'MyLayer2')
176-
177-
def handle(self, layer, inp_tensors, out_tensors):
178-
# handler code
179-
```
180-
"""
181-
182-
fn = cls()
183-
for k in fn.handles:
184-
registry[k] = fn
185-
return cls

hls4ml/converters/keras_v3/conv.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from math import ceil
44
from typing import Any
55

6-
from ._base import KerasV3LayerHandler, register
6+
from ._base import KerasV3LayerHandler
77

88
if typing.TYPE_CHECKING:
99
import keras
@@ -75,7 +75,6 @@ def gen_conv_config(
7575
return config
7676

7777

78-
@register
7978
class ConvHandler(KerasV3LayerHandler):
8079
handles = (
8180
'keras.src.layers.convolutional.conv1d.Conv1D',

hls4ml/converters/keras_v3/core.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55

66
import numpy as np
77

8-
from ._base import KerasV3LayerHandler, register
8+
from ._base import KerasV3LayerHandler
99

1010
if typing.TYPE_CHECKING:
1111
import keras
1212
from keras import KerasTensor
1313

1414

15-
@register
1615
class DenseHandler(KerasV3LayerHandler):
1716
handles = ('keras.src.layers.core.dense.Dense',)
1817

@@ -36,7 +35,6 @@ def handle(
3635
return config
3736

3837

39-
@register
4038
class InputHandler(KerasV3LayerHandler):
4139
handles = ('keras.src.layers.core.input_layer.InputLayer',)
4240

@@ -50,7 +48,6 @@ def handle(
5048
return config
5149

5250

53-
@register
5451
class ActivationHandler(KerasV3LayerHandler):
5552
handles = ('keras.src.layers.activations.activation.Activation',)
5653

@@ -89,7 +86,6 @@ def handle(
8986
return (config,)
9087

9188

92-
@register
9389
class ReLUHandler(KerasV3LayerHandler):
9490
handles = (
9591
'keras.src.layers.activations.leaky_relu.LeakyReLU',
@@ -136,7 +132,6 @@ def handle(
136132
return (config,)
137133

138134

139-
@register
140135
class SoftmaxHandler(KerasV3LayerHandler):
141136
handles = ('keras.src.layers.activations.softmax.Softmax',)
142137

@@ -169,7 +164,6 @@ def handle(
169164
return (config,)
170165

171166

172-
@register
173167
class EluHandler(KerasV3LayerHandler):
174168
handles = ('keras.src.layers.activations.elu.ELU',)
175169

@@ -190,7 +184,6 @@ def handle(
190184
return (config,)
191185

192186

193-
@register
194187
class ReshapeHandler(KerasV3LayerHandler):
195188
handles = ('keras.src.layers.reshaping.reshape.Reshape', 'keras.src.layers.reshaping.flatten.Flatten')
196189

@@ -206,7 +199,6 @@ def handle(
206199
}
207200

208201

209-
@register
210202
class PermuteHandler(KerasV3LayerHandler):
211203
handles = ('keras.src.layers.reshaping.permute.Permute',)
212204

@@ -220,7 +212,6 @@ def handle(
220212
return config
221213

222214

223-
@register
224215
class NoOp(KerasV3LayerHandler):
225216
handles = (
226217
'keras.src.layers.preprocessing.image_preprocessing.random_brightness.RandomBrightness',

hls4ml/converters/keras_v3/einsum_dense.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import typing
22
from collections.abc import Sequence
33

4-
from ._base import KerasV3LayerHandler, register
4+
from ._base import KerasV3LayerHandler
55

66
if typing.TYPE_CHECKING:
77
import keras
@@ -41,7 +41,6 @@ def strip_batch_dim(equation: str, einsum_dense: bool = True):
4141
return f'{inp0},{inp1}->{out}'
4242

4343

44-
@register
4544
class EinsumDenseHandler(KerasV3LayerHandler):
4645
handles = ('keras.src.layers.core.einsum_dense.EinsumDense',)
4746

hls4ml/converters/keras_v3/hgq2/_base.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from hls4ml.converters.keras_v3._base import KerasV3LayerHandler, register
7+
from hls4ml.converters.keras_v3._base import KerasV3LayerHandler
88
from hls4ml.converters.keras_v3.conv import ConvHandler
99
from hls4ml.converters.keras_v3.core import ActivationHandler, DenseHandler
1010
from hls4ml.converters.keras_v3.einsum_dense import EinsumDenseHandler
@@ -67,7 +67,6 @@ def override_io_tensor_confs(confs: tuple[dict[str, Any], ...], overrides: dict[
6767
conf['output_keras_tensor_names'] = [overrides.get(name, name) for name in out_tensor_names]
6868

6969

70-
@register
7170
class QLayerHandler(KerasV3LayerHandler):
7271
def __call__(
7372
self,
@@ -121,15 +120,13 @@ def default_class_name(self, layer: 'Layer') -> str:
121120
return class_name
122121

123122

124-
@register
125123
class QEinsumDenseHandler(QLayerHandler, EinsumDenseHandler):
126124
handles = (
127125
'hgq.layers.core.einsum_dense.QEinsumDense',
128126
'hgq.layers.einsum_dense_batchnorm.QEinsumDenseBatchnorm',
129127
)
130128

131129

132-
@register
133130
class QStandaloneQuantizerHandler(KerasV3LayerHandler):
134131
handles = ('hgq.quantizer.quantizer.Quantizer',)
135132

@@ -144,7 +141,6 @@ def handle(
144141
return conf
145142

146143

147-
@register
148144
class QConvHandler(QLayerHandler, ConvHandler):
149145
handles = (
150146
'hgq.layers.conv.QConv1D',
@@ -170,7 +166,6 @@ def handle(
170166
return conf
171167

172168

173-
@register
174169
class QDenseHandler(QLayerHandler, DenseHandler):
175170
handles = ('hgq.layers.core.dense.QDense', 'hgq.layers.core.dense.QBatchNormDense')
176171

@@ -189,12 +184,10 @@ def handle(
189184
return conf
190185

191186

192-
@register
193187
class QActivationHandler(QLayerHandler, ActivationHandler):
194188
handles = ('hgq.layers.activation.QActivation',)
195189

196190

197-
@register
198191
class QBatchNormalizationHandler(QLayerHandler):
199192
handles = ('hgq.layers.batch_normalization.QBatchNormalization',)
200193

@@ -220,7 +213,6 @@ def handle(
220213
}
221214

222215

223-
@register
224216
class QMergeHandler(QLayerHandler, MergeHandler):
225217
handles = (
226218
'hgq.layers.ops.merge.QAdd',

hls4ml/converters/keras_v3/hgq2/einsum.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
from collections.abc import Sequence
33

44
from ..einsum_dense import strip_batch_dim
5-
from ._base import QLayerHandler, register
5+
from ._base import QLayerHandler
66

77
if typing.TYPE_CHECKING:
88
import hgq
99
from keras import KerasTensor
1010

1111

12-
@register
1312
class QEinsumHandler(QLayerHandler):
1413
handles = ('hgq.layers.ops.einsum.QEinsum',)
1514

0 commit comments

Comments
 (0)