Skip to content

Commit 158f4e8

Browse files
authored
Add alkaid interface (#45)
* Initial alkaid interface * Refactored the PQ MHA quantizer flow and added an HGQ-style quantized Softmax * Added an upper limit to the integer bits in dynamic data quantization, which could previously exceed the total bitwidth
1 parent 734ab81 commit 158f4e8

15 files changed

Lines changed: 2089 additions & 240 deletions

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ optional-dependencies.test = [ "pytest>=8.4" ]
3030
optional-dependencies.torch = [ "torch>=2.1" ]
3131
urls.repository = "https://github.com/cern-nextgen/PQuantML"
3232

33+
entry-points."alkaid_keras".pquant = "pquant._alkaid_plugin._alkaid_keras_plugin:register"
34+
entry-points."alkaid_torch".pquant = "pquant._alkaid_plugin._alkaid_torch_plugin:register"
35+
3336
[tool.setuptools]
3437
packages = [ "pquant" ]
3538
include-package-data = true

src/pquant/_alkaid_plugin/__init__.py

Whitespace-only changes.
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import numpy as np
6+
from alkaid.trace.ops import quantize as alkaid_quantize
7+
8+
9+
class PQuantAlkaidError(ValueError):
    """Error for PQuant states that Alkaid is unable to replay.

    This is a ``ValueError`` subclass, so handlers catching ``ValueError``
    also catch it.
    """
11+
12+
13+
def to_numpy(value: Any) -> np.ndarray:
    """Convert ``value`` to a numpy array.

    Conversion rules, in order:

    * ``None`` becomes a 0-d array holding ``0.0``.
    * A ``np.ndarray`` is returned unchanged (same object, no copy).
    * An object exposing ``detach`` (presumably a torch tensor) is detached,
      moved with ``cpu()`` when that method exists, and converted with
      ``numpy()``.
    * Anything else goes through ``keras.ops.convert_to_numpy``; if keras is
      unavailable or the conversion fails, ``np.asarray`` is used instead.
    """
    if value is None:
        return np.array(0.0)
    if isinstance(value, np.ndarray):
        return value

    if not hasattr(value, 'detach'):
        # Keras is imported lazily so this helper also works without it.
        try:
            import keras

            converted = keras.ops.convert_to_numpy(value)
            return np.asarray(converted)
        except Exception:
            return np.asarray(value)

    tensor = value.detach()
    if hasattr(tensor, 'cpu'):
        tensor = tensor.cpu()
    return tensor.numpy()
29+
30+
31+
def to_bool(value: Any, default: bool = False) -> bool:
    """Collapse ``value`` to a single Python ``bool``.

    ``None`` yields ``default``. Otherwise the value is converted with
    ``to_numpy``: a 0-d result is tested directly, while any other shape is
    true only when every element is truthy. If the conversion raises, plain
    ``bool(value)`` is returned.
    """
    if value is None:
        return default
    try:
        flags = to_numpy(value)
    except Exception:
        return bool(value)
    is_scalar = flags.shape == ()
    truth = flags.item() if is_scalar else np.all(flags)
    return bool(truth)
41+
42+
43+
def to_int_bits(value: Any) -> np.ndarray:
    """Return ``value`` as an ``int64`` array, rounded with ``np.rint``."""
    rounded = np.rint(to_numpy(value))
    return rounded.astype(np.int64)
45+
46+
47+
def raw_module_attr(obj: Any, name: str, default: Any = None) -> Any:
    """Look up ``name`` on ``obj`` without going through custom attribute hooks first.

    The ``_parameters``, ``_buffers`` and ``_modules`` dicts of ``obj`` (when
    present) are searched in that order. Failing that, plain
    ``object.__getattribute__`` is tried, and only then the regular
    ``getattr`` with ``default`` as the fallback.
    """
    # The generator is lazy: each storage dict is fetched only if the
    # previous ones did not contain ``name``.
    containers = (getattr(obj, attr, None) for attr in ('_parameters', '_buffers', '_modules'))
    for container in containers:
        if isinstance(container, dict) and name in container:
            return container[name]

    try:
        found = object.__getattribute__(obj, name)
    except AttributeError:
        found = getattr(obj, name, default)
    return found
56+
57+
58+
def quantizer_kif(quantizer: Any) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the ``(k, i, f)`` bit arrays of ``quantizer`` as ``int64`` arrays.

    For objects with a ``_parameters`` attribute (presumably torch modules):

    * when ``use_hgq`` is falsy, the raw ``k``, ``i`` and ``f`` attributes of
      the quantizer itself are used;
    * otherwise the wrapped ``quantizer`` attribute is inspected, and if it
      has ``_parameters`` or ``_buffers`` its ``_k``, ``_i_raw`` (falling back
      to ``_i`` when that is ``None``) and ``_f`` are used.

    In every other case ``quantizer.get_quantization_bits()`` supplies the
    three values.
    """
    if hasattr(quantizer, '_parameters'):
        uses_hgq = bool(raw_module_attr(quantizer, 'use_hgq', False))
        if not uses_hgq:
            keep, integer, frac = (
                to_int_bits(raw_module_attr(quantizer, attr)) for attr in ('k', 'i', 'f')
            )
            return keep, integer, frac

        wrapped = raw_module_attr(quantizer, 'quantizer')
        if hasattr(wrapped, '_parameters') or hasattr(wrapped, '_buffers'):
            keep = raw_module_attr(wrapped, '_k')
            integer = raw_module_attr(wrapped, '_i_raw', None)
            if integer is None:
                integer = raw_module_attr(wrapped, '_i')
            frac = raw_module_attr(wrapped, '_f')
            return to_int_bits(keep), to_int_bits(integer), to_int_bits(frac)

    keep, integer, frac = quantizer.get_quantization_bits()
    return to_int_bits(keep), to_int_bits(integer), to_int_bits(frac)
76+
77+
78+
def replay_quantizer(quantizer: Any, x: Any) -> Any:
    """Apply ``quantizer`` to ``x`` through Alkaid's ``quantize`` op.

    The bit widths come from ``quantizer_kif``. The overflow mode is read from
    ``quantizer.overflow``, falling back to the wrapped quantizer's
    ``overflow_mode`` and finally ``'WRAP'``. The rounding mode is read from
    ``round_mode`` on the quantizer, then on the wrapped quantizer, and
    defaults to ``'TRN'``. Both modes are passed upper-cased.
    """
    k, i, f = quantizer_kif(quantizer)
    wrapped = raw_module_attr(quantizer, 'quantizer', None)

    wrapped_overflow = raw_module_attr(wrapped, 'overflow_mode', 'WRAP')
    overflow = raw_module_attr(quantizer, 'overflow', wrapped_overflow)

    wrapped_round = raw_module_attr(wrapped, 'round_mode', 'TRN')
    round_mode = raw_module_attr(quantizer, 'round_mode', wrapped_round)

    return alkaid_quantize(
        x,
        k=k,
        i=i,
        f=f,
        overflow_mode=str(overflow).upper(),
        round_mode=str(round_mode).upper(),
    )
84+
85+
86+
def replay_quantizer_if_enabled(layer: Any, quantizer_name: str, x: Any, flag_name: str) -> Any:
    """Replay one of ``layer``'s quantizers on ``x`` when quantization is switched on.

    Args:
        layer: Object owning the quantizer and the enabling flags.
        quantizer_name: Attribute name of the quantizer on ``layer``.
        x: Value to quantize.
        flag_name: Attribute name of the per-quantizer enable flag.

    Returns:
        ``x`` unchanged when either ``layer.enable_quantization`` or
        ``layer.<flag_name>`` is falsy (both default to ``True`` when absent);
        otherwise the result of ``replay_quantizer``.

    Raises:
        PQuantAlkaidError: Quantization is enabled but ``layer`` has no
            quantizer under ``quantizer_name``.
    """
    enabled = bool(getattr(layer, 'enable_quantization', True)) and bool(getattr(layer, flag_name, True))
    if not enabled:
        return x

    quantizer = getattr(layer, quantizer_name, None)
    if quantizer is None:
        # Fail with a message naming the layer and attribute, rather than an
        # AttributeError on None deep inside the quantizer lookup.
        raise PQuantAlkaidError(
            f'{layer.__class__.__name__} has {flag_name!r} enabled but no {quantizer_name!r} to replay.'
        )
    return replay_quantizer(quantizer, x)
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
from __future__ import annotations
2+
3+
from math import prod
4+
5+
import keras
6+
import numpy as np
7+
from alkaid.converter.builtin.keras.layers._base import ReplayOperationBase
8+
from alkaid.converter.builtin.keras.layers.activation import keras_numpy_unary_map
9+
from alkaid.converter.builtin.keras.layers.batchnorm import ReplayBatchNormalization
10+
from alkaid.converter.builtin.keras.layers.conv import _conv
11+
from alkaid.converter.builtin.keras.layers.pool import ReplayPool
12+
from alkaid.trace import FVArray
13+
from alkaid.trace.ops import einsum, extract_patches
14+
from keras.layers import DepthwiseConv1D, DepthwiseConv2D
15+
16+
from pquant._alkaid_plugin._alkaid_common import (
17+
PQuantAlkaidError,
18+
replay_quantizer,
19+
replay_quantizer_if_enabled,
20+
to_bool,
21+
to_numpy,
22+
)
23+
from pquant.core.keras.activations import PQActivation
24+
from pquant.core.keras.layers import (
25+
PQAvgPool1d,
26+
PQAvgPool2d,
27+
PQBatchNormalization,
28+
PQConv1d,
29+
PQConv2d,
30+
PQDense,
31+
PQDepthwiseConv2d,
32+
PQMultiheadAttention,
33+
PQSeparableConv2d,
34+
PQSoftmax,
35+
)
36+
from pquant.core.keras.quantizer import Quantizer
37+
38+
39+
def _assert_final_compression(layer) -> None:
    """Raise ``PQuantAlkaidError`` unless ``layer.final_compression_done`` is truthy.

    A missing attribute is treated as ``False``.
    """
    finalized = to_bool(getattr(layer, 'final_compression_done', False))
    if finalized:
        return
    raise PQuantAlkaidError(
        f'{layer.__class__.__name__} must have apply_final_compression() applied before Alkaid conversion.'
    )
44+
45+
46+
def _weight(layer) -> np.ndarray:
    """Return ``layer._kernel`` as a numpy array, after checking final compression was applied."""
    _assert_final_compression(layer)
    kernel = layer._kernel
    return to_numpy(kernel)
49+
50+
51+
def _bias(layer) -> np.ndarray:
    """Return ``layer._bias`` as a numpy array, after checking final compression was applied.

    A layer without a bias (attribute missing or ``None``) yields a 0-d array
    holding ``0.0``.
    """
    _assert_final_compression(layer)
    raw = getattr(layer, '_bias', None)
    return np.array(0.0) if raw is None else to_numpy(raw)
57+
58+
59+
class ReplayPQuantQuantizer(ReplayOperationBase):
    """Alkaid replay handler for standalone PQuant ``Quantizer`` layers."""

    handles = (Quantizer,)
    __activation_handled__ = True

    def call(self, x: FVArray) -> FVArray:
        """Quantize ``x`` with the wrapped quantizer."""
        quantizer = self.op
        return replay_quantizer(quantizer, x)
65+
66+
67+
class ReplayPQuantDense(ReplayOperationBase):
    """Alkaid replay handler for ``PQDense`` layers."""

    handles = (PQDense,)

    def call(self, inputs: FVArray) -> FVArray:
        """Replay input quantization, the affine map and output quantization."""
        dense = self.op
        quantized_in = replay_quantizer_if_enabled(dense, 'input_quantizer', inputs, 'quantize_input')
        # Contract the last input axis with the first kernel axis.
        projected = np.einsum('...c,cC->...C', quantized_in, _weight(dense))
        biased = projected + _bias(dense)
        return replay_quantizer_if_enabled(dense, 'output_quantizer', biased, 'quantize_output')
75+
76+
77+
class ReplayPQuantConv(ReplayOperationBase):
    """Alkaid replay handler for PQuant 1D/2D and depthwise 2D convolutions."""

    handles = (PQConv1d, PQConv2d, PQDepthwiseConv2d)

    def call(self, inputs: FVArray) -> FVArray:
        """Replay the convolution as patch extraction followed by ``_conv``.

        The input quantizer is replayed first and the output quantizer last;
        both are skipped when disabled on the layer.
        """
        conv = self.op
        inputs = replay_quantizer_if_enabled(conv, 'input_quantizer', inputs, 'quantize_input')
        kernel = _weight(conv)
        bias = _bias(conv)

        if not isinstance(conv, (DepthwiseConv1D, DepthwiseConv2D)):
            groups = conv.groups
        else:
            # Fold the last two kernel axes (in-channels, multiplier) into one
            # output axis and use one group per input channel.
            ch_in, depth_multiplier = kernel.shape[-2:]
            kernel = kernel.reshape(*kernel.shape[:-2], 1, ch_in * depth_multiplier)
            groups = ch_in

        patches = extract_patches(
            inputs,
            size=conv.kernel_size,
            strides=conv.strides,
            dilation_rate=conv.dilation_rate,
            padding=conv.padding,
            data_format=conv.data_format,
        )
        out = _conv(
            patches,
            kernel,
            k_vol=int(prod(conv.kernel_size)),
            groups=groups,
            ch_in_per_g=kernel.shape[-2],
            out_per_g=kernel.shape[-1] // groups,
        )

        # ``_bias`` returns a 0-d array when the layer has no bias.
        if bias.shape != ():
            out = out + bias
        if conv.data_format == 'channels_first':
            out = np.moveaxis(out, -1, 1)  # type: ignore
        return replay_quantizer_if_enabled(conv, 'output_quantizer', out, 'quantize_output')
117+
118+
119+
class ReplayPQuantSeparableConv(ReplayOperationBase):
    """Alkaid replay handler for ``PQSeparableConv2d`` layers."""

    handles = (PQSeparableConv2d,)

    def call(self, inputs: FVArray) -> FVArray:
        """Replay the depthwise sub-layer, then the pointwise sub-layer, via ``ReplayPQuantConv``."""
        separable = self.op
        depthwise = ReplayPQuantConv(separable.depthwise_conv)
        depthwise_out = depthwise.call(inputs)
        pointwise = ReplayPQuantConv(separable.pointwise_conv)
        return pointwise.call(depthwise_out)
126+
127+
128+
class ReplayPQuantBatchNormalization(ReplayBatchNormalization):
    """Alkaid replay handler for ``PQBatchNormalization`` layers."""

    handles = (PQBatchNormalization,)

    def fused_scale_offset(self) -> tuple[np.ndarray, np.ndarray]:
        """Fold the moving statistics, gamma and beta into one scale and one offset.

        Returns ``(scale, offset)`` with
        ``scale = gamma / sqrt(variance + epsilon)`` and
        ``offset = beta - mean * scale``. Gamma defaults to ones when
        ``layer.scale`` is falsy, and beta to zeros when ``layer.center`` is
        falsy.

        Raises:
            PQuantAlkaidError: Final compression has not been applied.
        """
        bn = self.op
        _assert_final_compression(bn)

        def as_array(variable) -> np.ndarray:
            # Cast to the layer dtype before leaving the keras backend.
            return to_numpy(keras.ops.cast(variable, bn.dtype))

        mean = as_array(bn.moving_mean)
        variance = as_array(bn.moving_variance)
        gamma = as_array(bn.gamma) if bn.scale else np.ones_like(mean)
        beta = as_array(bn.beta) if bn.center else np.zeros_like(mean)

        scale = gamma / np.sqrt(variance + bn.epsilon)
        return scale, beta - mean * scale

    def call(self, inputs: FVArray, mask=None) -> FVArray:
        """Replay input quantization and the fused affine transform.

        ``mask`` is accepted but not used. The multiply is skipped when the
        scale is all ones, and the add when the offset is all zeros.
        """
        bn = self.op
        inputs = replay_quantizer_if_enabled(bn, 'input_quantizer', inputs, 'quantize_input')
        scale, offset = self.fused_scale_offset()

        # Broadcast shape: 1 everywhere except on the normalized axes.
        rank = inputs.ndim
        broadcast_shape = [1] * rank
        axes = bn.axis if isinstance(bn.axis, (list, tuple)) else [bn.axis]
        for axis in axes:
            idx = axis if axis >= 0 else rank + axis
            broadcast_shape[idx] = inputs.shape[idx]

        out = inputs
        if not np.all(scale == 1):
            out = out * scale.reshape(broadcast_shape)  # type: ignore
        if not np.all(offset == 0):
            out = out + offset.reshape(broadcast_shape)  # type: ignore
        return out
163+
164+
165+
class ReplayPQuantAvgPool(ReplayPool):
    """Alkaid replay handler for ``PQAvgPool1d`` and ``PQAvgPool2d`` layers."""

    handles = (PQAvgPool1d, PQAvgPool2d)
    __activation_handled__ = True

    def call(self, inputs: FVArray, mask: None = None) -> FVArray:
        """Wrap the parent pooling replay with the layer's input and output quantizers."""
        pool = self.op
        quantized_in = replay_quantizer_if_enabled(pool, 'input_quantizer', inputs, 'quantize_input')
        pooled = super().call(quantized_in, mask=mask)
        return replay_quantizer_if_enabled(pool, 'output_quantizer', pooled, 'quantize_output')
174+
175+
176+
class ReplayPQuantActivation(ReplayOperationBase):
    """Alkaid replay handler for ``PQActivation`` layers."""

    handles = (PQActivation,)
    __activation_handled__ = True

    def call(self, inputs: FVArray) -> FVArray:
        """Replay the activation together with its quantizers.

        For a non-HGQ ``relu`` with ``use_multiplier`` set and a
        ``multiplier`` attribute, the input is first scaled by
        ``2 ** rint(multiplier)``.

        Raises:
            PQuantAlkaidError: ``activation_name`` is not in
                ``keras_numpy_unary_map``.
        """
        act = self.op

        uses_hgq = bool(getattr(act, 'use_hgq', False))
        if not uses_hgq and bool(getattr(act, 'use_multiplier', False)):
            if act.activation_name == 'relu' and hasattr(act, 'multiplier'):
                shift = np.rint(to_numpy(act.multiplier))
                inputs = inputs * (2.0 ** shift)

        inputs = replay_quantizer_if_enabled(act, 'input_quantizer', inputs, 'quantize_input')

        name = act.activation_name
        if name not in keras_numpy_unary_map:
            raise PQuantAlkaidError(f'Unsupported PQuant activation for Alkaid conversion: {name!r}')
        activated = keras_numpy_unary_map[name](inputs)
        return replay_quantizer_if_enabled(act, 'output_quantizer', activated, 'quantize_output')
194+
195+
196+
def _table_fn(table):
    """Wrap ``table.activation_function`` as a numpy-in, numpy-out callable.

    The returned function casts its argument to a float32 keras tensor,
    applies the activation, and hands back a float64 numpy array.
    """
    activation = table.activation_function

    def apply_fn(v: np.ndarray) -> np.ndarray:
        as_f32 = keras.ops.cast(keras.ops.convert_to_tensor(v), 'float32')
        result = keras.ops.convert_to_numpy(activation(as_f32))
        return np.asarray(result, dtype=np.float64)

    return apply_fn
205+
206+
207+
class ReplayPQuantSoftmax(ReplayOperationBase):
    """Alkaid replay handler for ``PQSoftmax`` layers built from two lookup tables."""

    handles = (PQSoftmax,)
    __activation_handled__ = True

    @staticmethod
    def _replay_table(table, x: FVArray) -> FVArray:
        """Replay one lookup table: optional input quantizer, table function, output quantizer.

        Raises:
            PQuantAlkaidError: The table's output quantizer is not enabled.
        """
        if not table.quantize_output or not table.enable_quantization:
            raise PQuantAlkaidError(
                f'PQSoftmax table {table.name!r} must have an enabled output quantizer for Alkaid conversion.'
            )
        quantized_in = replay_quantizer_if_enabled(table, 'input_quantizer', x, 'quantize_input')
        looked_up = quantized_in.apply(_table_fn(table))
        return replay_quantizer(table.output_quantizer, looked_up)

    def call(self, inputs: FVArray, mask=None) -> FVArray:
        """Replay the softmax as ``exp_table(x) * inv_table(sum(exp_table(x)))``.

        Raises:
            PQuantAlkaidError: A mask is supplied.
        """
        softmax = self.op
        if mask is not None:
            raise PQuantAlkaidError('PQSoftmax masks are not supported in Alkaid conversion.')

        inputs = replay_quantizer_if_enabled(softmax, 'input_quantizer', inputs, 'quantize_input')
        if softmax.stable:
            # Feed the exp table ``max - x`` (non-negative) instead of ``x``.
            inputs = np.max(inputs, axis=softmax.axes, keepdims=True) - inputs  # type: ignore

        numerators = self._replay_table(softmax.exp_table, inputs)
        totals = np.sum(numerators, axis=softmax.axes, keepdims=True)
        reciprocal = self._replay_table(softmax.inv_table, totals)
        weights = numerators * reciprocal
        return replay_quantizer_if_enabled(softmax, 'output_quantizer', weights, 'quantize_output')
233+
234+
235+
class ReplayPQuantMultiheadAttention(ReplayOperationBase):
    """Alkaid replay handler for ``PQMultiheadAttention`` layers."""

    handles = (PQMultiheadAttention,)
    __activation_handled__ = True

    def call(self, inputs, key_padding_mask=None, attn_mask=None, need_weights=True):
        """Replay scaled dot-product attention with PQuant projections and softmax.

        Args:
            inputs: A single array used as query, key and value, or a
                list/tuple: three items are ``(query, key, value)``, two items
                are ``(query, key)`` with the key reused as value, and any
                other length uses the first item for all three.
            key_padding_mask: Must be ``None``.
            attn_mask: Must be ``None``.
            need_weights: When true, also return the head-averaged attention
                weights.

        Returns:
            ``(out, weights)`` when ``need_weights`` is true, else ``(out,)``.

        Raises:
            PQuantAlkaidError: Either mask is supplied.
        """
        mha = self.op
        if key_padding_mask is not None or attn_mask is not None:
            raise PQuantAlkaidError('Attention masks are not supported in Alkaid conversion.')

        if not isinstance(inputs, (list, tuple)):
            query = key = value = inputs
        elif len(inputs) == 3:
            query, key, value = inputs
        elif len(inputs) == 2:
            query, key = inputs
            value = key
        else:
            query = key = value = inputs[0]

        batch_size, query_len = query.shape[0], query.shape[1]
        key_len = key.shape[1]
        num_heads, head_dim = mha.num_heads, mha.head_dim

        def split_heads(tensor, length):
            # (B, L, E) -> (B, H, L, head_dim)
            return tensor.reshape(batch_size, length, num_heads, head_dim).transpose(0, 2, 1, 3)

        q = ReplayPQuantDense(mha.q_proj).call(query)  # (B, T, E)
        k = ReplayPQuantDense(mha.k_proj).call(key)  # (B, S, E)
        v = ReplayPQuantDense(mha.v_proj).call(value)  # (B, S, E)

        q = split_heads(q, query_len)
        k = split_heads(k, key_len)
        v = split_heads(v, key_len)

        scale = float(np.float32(mha.scale))
        attn_scores = einsum('bhtd,bhsd->bhts', q, k) * scale

        # The softmax replays its own input and output quantizers on the scores and weights.
        attn_weights = ReplayPQuantSoftmax(mha.softmax).call(attn_scores)

        # Weighted sum of values; dropout is not replayed: (B, H, T, head_dim)
        context = einsum('bhts,bhsd->bhtd', attn_weights, v)

        # Merge heads back to (B, T, E) before the output projection.
        merged = context.transpose(0, 2, 1, 3).reshape(batch_size, query_len, mha.embed_dim)
        out = ReplayPQuantDense(mha.out_proj).call(merged)

        if not need_weights:
            return (out,)
        # Average attention weights over the head axis: (B, T, S)
        return out, np.mean(attn_weights, axis=1)
285+
286+
287+
def register() -> None:
    """Entry point for Alkaid's ``alkaid_keras`` second-level plugin group.

    Adds ``('pquant', 'keras')`` to Alkaid's private
    ``_plugin_loader._LOADED`` set. This is best-effort: if Alkaid cannot be
    imported or the private attribute is missing, the failure is logged at
    debug level and otherwise ignored.
    """
    import logging

    try:
        from alkaid.converter import _plugin_loader

        _plugin_loader._LOADED.add(('pquant', 'keras'))
    except Exception:
        # Deliberately broad: this touches a private Alkaid API, and a failure
        # here must not break plugin discovery. Keep the reason discoverable.
        logging.getLogger(__name__).debug(
            'Could not mark the pquant keras plugin as loaded in Alkaid.', exc_info=True
        )

0 commit comments

Comments
 (0)