|
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 |
|
8 | | -from hls4ml.model.layers import Conv1D, Conv2D, Dense, EinsumDense, Layer |
| 8 | +from hls4ml.model.layers import Conv1D, Conv2D, DACombinational, Dense, EinsumDense, Layer |
9 | 9 | from hls4ml.model.optimizer import OptimizerPass |
10 | 10 | from hls4ml.model.optimizer.passes.bit_exact import get_input_layers, get_output_layers, im2col, pad_arrs, stride_arrs |
11 | 11 | from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer |
@@ -408,3 +408,52 @@ def transform(self, model: 'ModelGraph', node: Layer): |
408 | 408 | del node.attributes['weight_data'] |
409 | 409 | del node.attributes['weight'] |
410 | 410 | del node.attributes['weight_t'] |
| 411 | + |
| 412 | + |
| 413 | +class DACombinationalTemplate(OptimizerPass): |
| 414 | + def match(self, node): |
| 415 | + return isinstance(node, DACombinational) |
| 416 | + |
| 417 | + def transform(self, model: 'ModelGraph', node: DACombinational): |
| 418 | + from da4ml.codegen.hls import hls_logic_and_bridge_gen |
| 419 | + from da4ml.trace import FixedVariableArrayInput, comb_trace |
| 420 | + |
| 421 | + io_type = model.config.get_config_value('IOType') |
| 422 | + if io_type != 'io_parallel': |
| 423 | + original_type = node.attributes['original_type'] |
| 424 | + raise ValueError(f'DACombinational layer (from {original_type}) only supports io_parallel.') |
| 425 | + |
| 426 | + if not node.attributes.get('bit_exact_transformed', False): |
| 427 | + inp_p: FixedPrecisionType = node.get_input_variable().type.precision |
| 428 | + B, I, s = inp_p.width, inp_p.integer, inp_p.signed |
| 429 | + i, f = I - s, B - I |
| 430 | + comb = node.attributes['da_comb_logic'] |
| 431 | + inp = FixedVariableArrayInput(comb.shape[0]).quantize(s, i, f) |
| 432 | + out = comb(inp) |
| 433 | + comb = comb_trace(inp, out) |
| 434 | + node.attributes['da_comb_logic'] = comb |
| 435 | + |
| 436 | + comb = node.attributes['da_comb_logic'] |
| 437 | + |
| 438 | + backend = model.config.get_config_value('Backend').lower() |
| 439 | + if backend in ('vitis', 'vivado'): |
| 440 | + flavor = 'vitis' |
| 441 | + elif backend == 'oneapi': |
| 442 | + flavor = 'oneapi' |
| 443 | + else: |
| 444 | + raise ValueError(f'Unsupported backend {backend} for DACombinational layer.') |
| 445 | + |
| 446 | + fn_name = f'da_comblogic_{node.index}' |
| 447 | + comb_logic, _ = hls_logic_and_bridge_gen( |
| 448 | + comb, fn_name, flavor=flavor, pragmas=['#pragma HLS INLINE'], print_latency=True |
| 449 | + ) |
| 450 | + namespace = model.config.get_writer_config().get('Namespace', None) or 'nnet' |
| 451 | + |
| 452 | + inp_t: str = node.get_input_variable().type.name |
| 453 | + out_t: str = node.get_output_variable().type.name |
| 454 | + inp_name: str = node.get_input_variable().name |
| 455 | + out_name: str = node.get_output_variable().name |
| 456 | + |
| 457 | + fn_cpp = f'{namespace}::{fn_name}<{inp_t}, {out_t}>({inp_name}, {out_name});' |
| 458 | + node.attributes['da_codegen'] = Source(comb_logic) |
| 459 | + node.attributes['function_cpp'] = fn_cpp |
0 commit comments