Skip to content

Commit 8509298

Browse files
authored
feat(bit_exact): embedding (#1465)
1 parent 9ad8f7c commit 8509298

1 file changed

Lines changed: 23 additions & 0 deletions

File tree

hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import numpy as np
1313
from numpy.typing import NDArray
14+
from quantizers import get_fixed_quantizer_np
1415

1516
from hls4ml.model.layers import (
1617
Activation,
@@ -22,6 +23,7 @@
2223
Dense,
2324
Einsum,
2425
EinsumDense,
26+
Embedding,
2527
GlobalPooling1D,
2628
GlobalPooling2D,
2729
Input,
@@ -654,6 +656,27 @@ def _(layer: DACombinational):
654656
return k.astype(np.int16), i.astype(np.int16), f.astype(np.int16)
655657

656658

659+
@_produce_kif.register
def _(layer: Embedding):
    """Quantize the embedding table in place and report its (k, i, f) bit widths.

    The table stored in ``layer.attributes['embeddings'].data`` is passed through
    a fixed-point quantizer built from the rounding/saturation modes of the
    layer's single output quantizer, then written back to the layer. The
    returned arrays come from ``minimal_kif`` on the quantized table, reduced
    along axis 0 and broadcast to the layer's output shape.

    Returns:
        Tuple ``(k, i, f)`` of ``np.int16`` arrays, each broadcast to
        ``get_output_shape(layer)``.

    Raises:
        AssertionError: if the layer does not have exactly one output quantizer.
    """
    _, consumer_quantizers = get_output_layers_and_quantizers(layer)
    assert len(consumer_quantizers) == 1, 'Embedding layer should have exactly one consumer'
    consumer = consumer_quantizers[0]

    # NOTE(review): mask_kbi presumably holds (keep_negative, total bits,
    # integer bits incl. sign) with a leading axis of which only entry 0 is
    # used here -- confirm against the quantizer definition.
    keep_neg, total_bits, int_with_sign = consumer.mask_kbi
    keep_neg = keep_neg[0]
    total_bits = total_bits[0]
    int_with_sign = int_with_sign[0]

    # Convert the (k, b, i) description into (k, i, f).
    int_bits = int_with_sign - keep_neg
    frac_bits = total_bits - int_with_sign

    # Array-valued widths are collapsed with a max over axis 1 of the stacked
    # [k, i, f]; scalar widths are used as they are.
    if isinstance(keep_neg, np.ndarray):
        keep_neg, int_bits, frac_bits = np.max([keep_neg, int_bits, frac_bits], axis=1)

    quantize = get_fixed_quantizer_np(consumer.RND, consumer.SAT)
    table = layer.attributes['embeddings'].data
    quantized_table = quantize(table, keep_neg, int_bits, frac_bits)
    # Side effect: the layer's stored embeddings are replaced by the quantized values.
    layer.attributes['embeddings'].data = quantized_table

    widths = minimal_kif(quantized_table)
    k, i, f = widths
    out_shape = get_output_shape(layer)

    # Take the max along axis 0 of each width array, then expand to the output shape.
    return tuple(
        np.broadcast_to(np.max(width, axis=0).astype(np.int16), out_shape)
        for width in (k, i, f)
    )
678+
679+
657680
def kif_arrs_to_ints(arr: tuple[np.ndarray, np.ndarray, np.ndarray]):
    """Reduce each array in ``arr`` to its maximum and return them as a tuple of Python ints."""
    # np.max collapses every array to a scalar; int() converts it to a builtin int.
    return tuple(map(int, map(np.max, arr)))
659682

0 commit comments

Comments
 (0)