|
17 | 17 | import functools |
18 | 18 | import json |
19 | 19 | import re |
20 | | -from typing import Tuple, Sequence |
| 20 | +from typing import Tuple, Sequence, Callable |
21 | 21 | from dataclasses import dataclass |
22 | 22 |
|
23 | 23 | from aqt.jax.v2 import config as aqt_config |
|
27 | 27 | from aqt.jax.v2 import calibration |
28 | 28 |
|
29 | 29 | import qwix |
| 30 | +from qwix._src.core import dot_general_qt |
30 | 31 |
|
31 | 32 | import jax |
32 | 33 | import jax.numpy as jnp |
@@ -194,6 +195,88 @@ def einsum(self, mesh_axes: Tuple[str, ...] = ()): |
194 | 195 | return aqt_einsum |
195 | 196 |
|
196 | 197 |
|
@dataclass
class QwixQuantization:
  """Qwix-based quantization (github.com/google/qwix); intended for training only.

  Exposes the same `dot_general_cls` / `einsum` provider interface as the other
  quantization config classes in this module, backed by Qwix's dot_general_qt.
  """

  # QuantMode is read by external callers; keep it even though Qwix itself
  # does not consume it here.
  quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN
  act_calibration_method: str = "absmax"
  weight_calibration_method: str = "absmax"
  bwd_calibration_method: str = "absmax"

  def _get_fp8_full_qwix_config(self) -> dot_general_qt.DotGeneralQtConfig:
    """Builds the fp8_full Qwix config: e4m3 forward operands, e5m2 gradients."""
    return dot_general_qt.DotGeneralQtConfig(
        lhs_qtype=jnp.float8_e4m3fn,  # activation
        rhs_qtype=jnp.float8_e4m3fn,  # weight
        dlhs_grad_qtype=jnp.float8_e5m2,  # activation gradient
        drhs_grad_qtype=jnp.float8_e5m2,  # weight gradient
        lhs_calibration_method=self.act_calibration_method,
        rhs_calibration_method=self.weight_calibration_method,
        dlhs_grad_calibration_method=self.bwd_calibration_method,
        drhs_grad_calibration_method=self.bwd_calibration_method,
        tile_size=None,
    )

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns a QwixDotGeneral constructor bound to the fp8_full config."""
    del mesh_axes  # unused; present for interface parity with other providers
    qwix_cfg = self._get_fp8_full_qwix_config()
    return functools.partial(QwixDotGeneral, config=qwix_cfg)

  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns a QwixEinsum module bound to the fp8_full config."""
    del mesh_axes  # unused; present for interface parity with other providers
    return QwixEinsum(config=self._get_fp8_full_qwix_config())
| 229 | + |
class QwixDotGeneral(nn.Module):
  """Flax module that runs a Qwix-quantized `dot_general` (dot_general_qt).

  Drop-in replacement for `jax.lax.dot_general`: it mirrors that signature so
  it can be injected into layers, then forwards to `dot_general_qt` with the
  configured quantization settings.
  """

  # Qwix quantization settings applied to every dot_general call.
  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      dimension_numbers: jax.lax.DotDimensionNumbers,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      *,
      out_sharding=None,
  ) -> jax.Array:

    # `precision`, `preferred_element_type` and `out_sharding` are accepted
    # only for signature compatibility with jax.lax.dot_general; they are NOT
    # forwarded to Qwix. NOTE(review): confirm dot_general_qt has no way to
    # honor them — otherwise callers' precision requests are silently dropped.
    return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)
| 249 | + |
class QwixEinsum(nn.Module):
  """Flax module that runs `jnp.einsum` with a Qwix-quantized `dot_general`.

  Mirrors the `jnp.einsum` calling convention so it can replace einsum inside
  layers, while routing the underlying contraction through `dot_general_qt`.
  """

  # Qwix quantization settings applied to the einsum's inner dot_general.
  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      einsum_str: str,
      *operands: jax.Array,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      _dot_general: Callable[..., jax.Array] | None = None,
      out_sharding=None,
  ) -> jax.Array:

    def custom_dot_general(*args, **kwargs):
      # Forward only (lhs, rhs, dimension_numbers) to Qwix; any kwargs that
      # jnp.einsum supplies (e.g. precision, preferred_element_type) are
      # intentionally dropped here. A caller-provided `_dot_general` is also
      # ignored — Qwix's dot_general always takes precedence.
      return dot_general_qt.dot_general_qt(*args[:3], self.config)

    # NOTE(review): jit is disabled around the whole einsum — presumably a
    # workaround needed for dot_general_qt tracing; confirm it is still
    # required, since it also disables jit for the surrounding contraction.
    with jax.disable_jit():
      return jnp.einsum(
          einsum_str,
          *operands,
          precision=precision,
          preferred_element_type=preferred_element_type,
          _dot_general=custom_dot_general,  # overrides the private hook
          out_sharding=out_sharding,
      )
| 279 | + |
197 | 280 | @dataclass |
198 | 281 | class Fp8Quantization(Quantization): |
199 | 282 | """Configures Fp8 quantization for NVIDIA GPUs""" |
@@ -546,6 +629,15 @@ def get_quant_mode(quant_mode_str: str = "train"): |
546 | 629 |
|
547 | 630 | def configure_quantization(config: Config, quant_mode_str: str = "train"): |
548 | 631 | """Configure quantization based on user config and quant mode.""" |
| 632 | + if config.use_batch_split_schedule and config.quantization: |
| 633 | + if not (config.use_qwix_quantization and config.quantization == "fp8_full"): |
| 634 | + raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`") |
| 635 | + return QwixQuantization( |
| 636 | + weight_calibration_method=config.weight_quantization_calibration_method, |
| 637 | + act_calibration_method=config.act_quantization_calibration_method, |
| 638 | + bwd_calibration_method=config.bwd_quantization_calibration_method, |
| 639 | + ) |
| 640 | + |
549 | 641 | if config.use_qwix_quantization: |
550 | 642 | return None |
551 | 643 | quant_cfg = _get_quant_config(config) |
@@ -726,7 +818,8 @@ def get_qt_provider(config): |
726 | 818 |
|
727 | 819 | def maybe_quantize_model(model, config): |
728 | 820 | """Quantize the model if quantization is enabled.""" |
729 | | - if config.use_qwix_quantization: |
| 821 | + # Batch split is not using Qwix's interception feature but manual plumbing |
| 822 | + if config.use_qwix_quantization and not config.use_batch_split_schedule: |
730 | 823 | quantization_provider = get_qt_provider(config) |
731 | 824 | if quantization_provider: |
732 | 825 | model = qwix.quantize_model(model, quantization_provider) |
|
0 commit comments