Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/source/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@ config.quantization_parameters.default_weight_fractional_bits = 3.
config.quantization_parameters.use_relu_multiplier = False
```

### MDMM pruning (constraint-based)

MDMM (Modified Differential Method of Multipliers) prunes by enforcing a *constraint* on a chosen sparsity metric instead of adding a fixed penalty. The `metric_type` field picks what is constrained — including two hardware-aware options that target FPGA resources directly.

```python
from pquant import mdmm_config, mdmm_fpga_config, mdmm_paca_config

config = mdmm_config() # default: UnstructuredSparsity metric
config = mdmm_fpga_config() # hardware-aware: FPGAAwareSparsity (DSP/BRAM grouping)
config = mdmm_paca_config() # hardware-aware: PACAPatternSparsity (dominant conv-kernel patterns)

# switch the constrained metric and its parameters
config.pruning_parameters.metric_type = "FPGAAwareSparsity"
config.pruning_parameters.target_resource = "DSP" # or "BRAM"
config.pruning_parameters.rf = 4
```

Training proceeds in three phases handled by the training loop: a warm-up where the constraint is inactive, an active phase where the constraint loss is applied and the prune mask is tracked, and fine-tuning where the mask is frozen and applied. The available `metric_type` options and their parameters are listed in the [Usage Reference](reference.md).

### Building a model
PQuantML supports two ways of defining compressed models. Below we illustrate both approaches using a simple jet-tagging architecture.

Expand Down
39 changes: 38 additions & 1 deletion docs/source/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ There are more details about every pruning method:
| `pruning_method` | str | `mdmm` | Selects this pruning schema. |
| `constraint_type` | ConstraintType | `"Equality"` | Constraint form: equality / ≤ / ≥. |
| `target_value` | float | `0.0` | Target value for the chosen metric. |
| `metric_type` | MetricType | `"UnstructuredSparsity"` | Specifies which metric is constrained. |
| `metric_type` | MetricType | `"UnstructuredSparsity"` | Quantity the constraint acts on — see **MDMM metric types** below. |
| `target_sparsity` | float | `0.9` | Target sparsity when constraining sparsity. |
| `rf` | int | `1` | Regularization / frequency parameter. |
| `epsilon` | float | `1.0e-03` | Feasibility tolerance. |
Expand All @@ -180,6 +180,43 @@ There are more details about every pruning method:
| `scale_mode` | `"mean"` \| `"sum"` | `"mean"` | Aggregation mode for penalties. |


##### MDMM metric types

The `metric_type` field selects which quantity the MDMM constraint drives. The first two are magnitude-based; the last two are hardware-aware and act on 4D convolution kernels.

| **metric_type** | **Constrains** |
|------------------------|-------------------------------------------------------------------------------|
| `UnstructuredSparsity` | Element-wise (L0/L1) sparsity toward `target_sparsity`. |
| `StructuredSparsity` | Fraction of all-zero weight groups of size `rf`. |
| `FPGAAwareSparsity` | Fraction of zero DSP/BRAM weight groups, modelling FPGA resource packing. |
| `PACAPatternSparsity` | Mean distance of each conv kernel to a small set of dominant binary patterns. |

**`FPGAAwareSparsity` parameters** (used only when `metric_type: FPGAAwareSparsity`):

| **Field** | **Type** | **Default** | **Description** |
|-------------------|---------------------|-------------|-----------------------------------------------------------------------------|
| `precision` | int | `16` | Weight bit-width used to derive BRAM packing. |
| `target_resource` | `"DSP"` \| `"BRAM"` | `"DSP"` | Hardware resource whose group-sparsity is measured. |
| `bram_width` | int | `36` | BRAM word width; sets how many DSP groups pack into one BRAM (`BRAM` only). |

Weights are grouped into DSP blocks of size `rf`; for `target_resource: BRAM`, `c = bram_width // precision` (or `2*bram_width // precision` when not divisible) consecutive DSP groups pack into one BRAM block. The metric reports the fraction of such groups whose L2 norm is below `epsilon`.

**`PACAPatternSparsity` parameters** (used only when `metric_type: PACAPatternSparsity`):

| **Field** | **Type** | **Default** | **Description** |
|------------------------|-------------------------------------------------|--------------------|--------------------------------------------------------------|
| `num_patterns_to_keep` | int | `16` | Maximum number of dominant kernel patterns retained. |
| `beta` | float | `0.75` | Cumulative pattern-frequency coverage kept, in `[0, 1]`. |
| `distance_metric` | `"hamming"` \| `"valued_hamming"` \| `"cosine"` | `"valued_hamming"` | Distance from each kernel to its closest dominant pattern. |

```{note}
`PACAPatternSparsity` always pairs with an equality constraint at target `0` (driving every kernel onto a dominant pattern); the config model sets `constraint_type` and `target_value` for you. During fine-tuning the kernels are projected onto their closest dominant pattern.
```

```{note}
The hardware-aware metrics operate on 4D convolution weights; for non-convolutional layers `PACAPatternSparsity` is a no-op. Ready-made configs are available via `mdmm_fpga_config()` and `mdmm_paca_config()`.
```

Optionally, there is also FITCompress method implemented for PyTorch:
### FitCompress method
| **Field** | **Type** | **Default** | **Description** |
Expand Down
8 changes: 8 additions & 0 deletions src/pquant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
load_from_dictionary,
load_from_file,
mdmm_config,
mdmm_fpga_config,
mdmm_paca_config,
pdp_config,
wanda_config,
)
Expand Down Expand Up @@ -58,6 +60,8 @@
_forwards.append("cs_config")
_forwards.append("dst_config")
_forwards.append("mdmm_config")
_forwards.append("mdmm_fpga_config")
_forwards.append("mdmm_paca_config")
_forwards.append("pdp_config")
_forwards.append("wanda_config")
_forwards.append("fitcompress_config")
Expand All @@ -81,6 +85,8 @@
load_from_dictionary,
load_from_file,
mdmm_config,
mdmm_fpga_config,
mdmm_paca_config,
pdp_config,
wanda_config,
)
Expand Down Expand Up @@ -114,6 +120,8 @@
_forwards.append("cs_config")
_forwards.append("dst_config")
_forwards.append("mdmm_config")
_forwards.append("mdmm_fpga_config")
_forwards.append("mdmm_paca_config")
_forwards.append("pdp_config")
_forwards.append("wanda_config")
_forwards.append("load_from_file")
Expand Down
72 changes: 72 additions & 0 deletions src/pquant/configs/config_mdmm_fpga.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# file: /src/pquant/configs/config_mdmm_fpga.yaml
# MDMM pruning with the hardware-aware FPGA metric (DSP/BRAM resource grouping).

pruning_parameters:
pruning_method: mdmm
enable_pruning: true
disable_pruning_for_layers: [] # Disable pruning for these layers, even if enable_pruning is true
constraint_type: "Equality"
target_value: 0.0
metric_type: "FPGAAwareSparsity"
target_sparsity: 0.9
rf: 4
epsilon: 1.0e-03
scale: 50.0
damping: 1.0
use_grad: false
constraint_lr: 1.0e-3
# FPGAAwareSparsityMetric-specific
precision: 16 # weight bit-width
target_resource: "DSP" # "DSP" or "BRAM"
bram_width: 36 # BRAM word width (only used when target_resource == "BRAM")

quantization_parameters:
enable_quantization: true
default_weight_keep_negatives: 1.
default_weight_integer_bits: 0.
default_weight_fractional_bits: 7.
default_data_keep_negatives: 0.
default_data_integer_bits: 0.
default_data_fractional_bits: 7.
dynamic_data_quantization: false
granularity: "per_tensor"
quantize_input: true
quantize_output: false
hgq_beta: 1e-5
hgq_gamma: 0.0003
hgq_heterogeneous: True
layer_specific: {}
use_high_granularity_quantization: false
use_real_tanh: false
use_relu_multiplier: false
use_symmetric_quantization: false
overflow_mode_parameters: SAT
overflow_mode_data: SAT
round_mode: RND
training_parameters:
epochs: 200
fine_tuning_epochs: 30
pretraining_epochs: 0
pruning_first: true
rewind: never
rounds: 1
save_weights_epoch: -1
fitcompress_parameters:
enable_fitcompress : false
optimize_quantization : true
quantization_schedule : [7.,4.,3.,2]
pruning_schedule : {start : 0, end : -3, steps : 40}
compression_goal : 0.10
optimize_pruning : false
greedy_astar : true
approximate : true
f_lambda : 1
hpo_parameters:
experiment_name: experiment_name
model_name: jet_tagger
num_trials: 1
sampler:
type: RandomSampler
hyperparameter_search:
numerical: {}
categorical: {}
74 changes: 74 additions & 0 deletions src/pquant/configs/config_mdmm_paca.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# file: /src/pquant/configs/config_mdmm_paca.yaml
# MDMM pruning with the PACA pattern metric. PACA always pairs with an equality
# constraint at target 0 (enforced by the config model); the values below are listed
# for clarity.

pruning_parameters:
pruning_method: mdmm
enable_pruning: true
disable_pruning_for_layers: [] # Disable pruning for these layers, even if enable_pruning is true
constraint_type: "Equality" # Forced to Equality/0 for PACA by MDMMPruningModel.
target_value: 0.0
metric_type: "PACAPatternSparsity"
target_sparsity: 0.9
rf: 1
epsilon: 1.0e-03
scale: 1.0
damping: 1.0
use_grad: false
constraint_lr: 1.0e-3
# PACAPatternMetric-specific
num_patterns_to_keep: 16
beta: 0.85 # cumulative pattern-frequency coverage kept (0..1)
distance_metric: "cosine" # "hamming", "valued_hamming", or "cosine"

quantization_parameters:
enable_quantization: true
default_weight_keep_negatives: 1.
default_weight_integer_bits: 0.
default_weight_fractional_bits: 7.
default_data_keep_negatives: 0.
default_data_integer_bits: 0.
default_data_fractional_bits: 7.
dynamic_data_quantization: false
granularity: "per_tensor"
quantize_input: true
quantize_output: false
hgq_beta: 1e-5
hgq_gamma: 0.0003
hgq_heterogeneous: True
layer_specific: {}
use_high_granularity_quantization: false
use_real_tanh: false
use_relu_multiplier: false
use_symmetric_quantization: false
overflow_mode_parameters: SAT
overflow_mode_data: SAT
round_mode: RND
training_parameters:
epochs: 200
fine_tuning_epochs: 30
pretraining_epochs: 0
pruning_first: true
rewind: never
rounds: 1
save_weights_epoch: -1
fitcompress_parameters:
enable_fitcompress : false
optimize_quantization : true
quantization_schedule : [7.,4.,3.,2]
pruning_schedule : {start : 0, end : -3, steps : 40}
compression_goal : 0.10
optimize_pruning : false
greedy_astar : true
approximate : true
f_lambda : 1
hpo_parameters:
experiment_name: experiment_name
model_name: jet_tagger
num_trials: 1
sampler:
type: RandomSampler
hyperparameter_search:
numerical: {}
categorical: {}
17 changes: 17 additions & 0 deletions src/pquant/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,20 @@
CONFIG_FILE = "config.yaml"

N_JOBS = 1

# --- Hardware-aware pruning metric constants ---
# Conv-kernel layout -> axis index, used to canonicalise weight layouts in the PACA
# pattern utilities. Keras conv weights are HWIO, Torch conv weights are OIHW.
CONV_LAYOUT_AXES = {"H": 0, "W": 1, "I": 2, "O": 3}
CANONICAL_CONV_LAYOUT = "OIHW"

# PACAPatternMetric pattern-distance metrics
DISTANCE_HAMMING = "hamming"
DISTANCE_VALUED_HAMMING = "valued_hamming"
DISTANCE_COSINE = "cosine"
PACA_DISTANCE_METRICS = (DISTANCE_HAMMING, DISTANCE_VALUED_HAMMING, DISTANCE_COSINE)

# FPGAAwareSparsityMetric target hardware resources
TARGET_RESOURCE_DSP = "DSP"
TARGET_RESOURCE_BRAM = "BRAM"
FPGA_TARGET_RESOURCES = (TARGET_RESOURCE_DSP, TARGET_RESOURCE_BRAM)
14 changes: 14 additions & 0 deletions src/pquant/core/hyperparameter_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,20 @@ def mdmm_config():
return PQConfig.load_from_file(path)


def mdmm_fpga_config():
yaml_name = "config_mdmm_fpga.yaml"
parent = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
path = os.path.join(parent, "configs", yaml_name)
return PQConfig.load_from_file(path)


def mdmm_paca_config():
yaml_name = "config_mdmm_paca.yaml"
parent = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
path = os.path.join(parent, "configs", yaml_name)
return PQConfig.load_from_file(path)


def pdp_config():
yaml_name = "config_pdp.yaml"
parent = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
Expand Down
28 changes: 26 additions & 2 deletions src/pquant/core/keras/pruning_methods/mdmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
LessThanOrEqualConstraint,
)
from pquant.core.keras.pruning_methods.metric_functions import (
FPGAAwareSparsityMetric,
PACAPatternMetric,
StructuredSparsityMetric,
UnstructuredSparsityMetric,
)

METRIC_REGISTRY = {
"UnstructuredSparsity": UnstructuredSparsityMetric,
"StructuredSparsity": StructuredSparsityMetric,
"FPGAAwareSparsity": FPGAAwareSparsityMetric,
"PACAPatternSparsity": PACAPatternMetric,
}

CONSTRAINT_REGISTRY = {
Expand Down Expand Up @@ -67,6 +71,14 @@ def build(self, input_shape):
"l0_mode": l0_mode,
"scale_mode": scale_mode,
"rf": pruning_parameters.rf,
# FPGAAwareSparsityMetric
"precision": pruning_parameters.precision,
"target_resource": pruning_parameters.target_resource,
"bram_width": pruning_parameters.bram_width,
# PACAPatternMetric
"num_patterns_to_keep": pruning_parameters.num_patterns_to_keep,
"beta": pruning_parameters.beta,
"distance_metric": pruning_parameters.distance_metric,
}

metric_cls = METRIC_REGISTRY.get(metric_type)
Expand Down Expand Up @@ -110,9 +122,18 @@ def build(self, input_shape):
self.constraint_layer.build(input_shape)
super().build(input_shape)

def _compute_hard_mask(self, weight, epsilon):
# During fine-tuning, a metric that defines its own projection (e.g. PACA pattern
# pruning) supplies the mask; otherwise use the magnitude threshold. The layer only
# checks for the capability, so it stays metric-agnostic (no metric-type branching).
metric_fn = getattr(self.constraint_layer, "metric_fn", None)
if self._is_finetuning and hasattr(metric_fn, "get_projection_mask"):
return ops.cast(metric_fn.get_projection_mask(weight), weight.dtype)
return ops.cast(ops.abs(weight) > epsilon, weight.dtype)

def call(self, weight):
epsilon = self.config.pruning_parameters.epsilon
hard_mask = ops.cast(ops.abs(weight) > epsilon, weight.dtype)
hard_mask = self._compute_hard_mask(weight, epsilon)
not_active = ops.logical_or(self.is_pretraining, self.is_finetuning)
self.mask.assign(ops.where(not_active, ops.convert_to_tensor(self.mask), hard_mask))

Expand All @@ -131,7 +152,10 @@ def get_hard_mask(self, weight=None):
return ops.cast(ops.abs(weight) > epsilon, weight.dtype)

def get_layer_sparsity(self, weight):
return ops.sum(self.get_hard_mask(weight)) / ops.size(weight)
# Cast size to the mask dtype: ops.sum(mask) is float but ops.size is int, and the
# TensorFlow backend rejects float/int division (the original float32/int32 bug).
mask = self.get_hard_mask(weight)
return ops.sum(mask) / ops.cast(ops.size(weight), mask.dtype)

def calculate_additional_loss(self):
# Loss is added via self.add_loss() in call() for model.fit.
Expand Down
Loading