Skip to content

Commit 0e856a0

Browse files
authored
initial Torch HGQ and pruning layers (#39)
initial Torch HGQ and pruning layers * Torch versions of HGQ Quantizer and pruning methods, used by the Torch PQlayers
1 parent d9fe442 commit 0e856a0

39 files changed

Lines changed: 2533 additions & 199 deletions

src/pquant/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# flake8: noqa
66
backend = os.getenv("KERAS_BACKEND", "tensorflow")
77
if backend == "torch":
8-
from . import configs, pruning_methods
8+
from . import configs
99
from .core.hyperparameter_optimization import (
1010
PQConfig,
1111
ap_config,
@@ -19,7 +19,7 @@
1919
pdp_config,
2020
wanda_config,
2121
)
22-
from .core.torch import activations, layers, optimizers, quantizer
22+
from .core.torch import activations, layers, optimizers, pruning_methods, quantizer
2323
from .core.torch.layers import (
2424
add_compression_layers,
2525
apply_final_compression,
@@ -61,7 +61,7 @@
6161
__all__ = _forwards
6262

6363
else:
64-
from . import configs, pruning_methods
64+
from . import configs
6565
from .core.hyperparameter_optimization import (
6666
PQConfig,
6767
ap_config,
@@ -74,7 +74,7 @@
7474
pdp_config,
7575
wanda_config,
7676
)
77-
from .core.keras import activations, layers, quantizer
77+
from .core.keras import activations, layers, pruning_methods, quantizer
7878
from .core.keras.layers import (
7979
add_compression_layers,
8080
apply_final_compression,

src/pquant/core/constants.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,6 @@
1010
PDPPruningModel,
1111
WandaPruningModel,
1212
)
13-
from pquant.pruning_methods.constraint_functions import (
14-
EqualityConstraint,
15-
GreaterThanOrEqualConstraint,
16-
LessThanOrEqualConstraint,
17-
)
18-
from pquant.pruning_methods.metric_functions import (
19-
StructuredSparsityMetric,
20-
UnstructuredSparsityMetric,
21-
)
2213

2314
PRUNING_MODEL_REGISTRY = {
2415
"cs": CSPruningModel,
@@ -53,15 +44,3 @@
5344
CONFIG_FILE = "config.yaml"
5445

5546
N_JOBS = 1
56-
57-
58-
METRIC_REGISTRY = {
59-
"UnstructuredSparsity": UnstructuredSparsityMetric,
60-
"StructuredSparsity": StructuredSparsityMetric,
61-
}
62-
63-
CONSTRAINT_REGISTRY = {
64-
"Equality": EqualityConstraint,
65-
"LessThanOrEqual": LessThanOrEqualConstraint,
66-
"GreaterThanOrEqual": GreaterThanOrEqualConstraint,
67-
}

src/pquant/core/keras/layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pquant.core.hyperparameter_optimization import PQConfig
2626
from pquant.core.keras.activations import PQActivation
2727
from pquant.core.keras.quantizer import Quantizer
28-
from pquant.core.utils import get_pruning_layer
28+
from pquant.core.keras.utils import get_pruning_layer
2929

3030
T = TypeVar("T")
3131

File renamed without changes.

src/pquant/pruning_methods/activation_pruning.py renamed to src/pquant/core/keras/pruning_methods/activation_pruning.py

File renamed without changes.

src/pquant/pruning_methods/autosparse.py renamed to src/pquant/core/keras/pruning_methods/autosparse.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,9 @@ def call(self, weight):
113113
is_training = ops.logical_not(ops.logical_or(self.is_pretraining, self.is_finetuning))
114114
self.mask.assign(ops.where(is_training, new_binary_mask, ops.convert_to_tensor(self.mask)))
115115

116-
sparse_weight = ops.sign(weight) * ops.reshape(autosparse_prune(w_t, self.alpha), weight.shape)
116+
sparse_weight = ops.sign(weight) * ops.reshape(
117+
autosparse_prune(w_t, ops.convert_to_tensor(self.alpha)), weight.shape
118+
)
117119

118120
return ops.where(
119121
self.is_pretraining,

src/pquant/pruning_methods/constraint_functions.py renamed to src/pquant/core/keras/pruning_methods/constraint_functions.py

File renamed without changes.
File renamed without changes.
File renamed without changes.

src/pquant/pruning_methods/mdmm.py renamed to src/pquant/core/keras/pruning_methods/mdmm.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,26 @@
88
import keras
99
from keras import ops
1010

11-
from pquant.core.constants import CONSTRAINT_REGISTRY, METRIC_REGISTRY
11+
from pquant.core.keras.pruning_methods.constraint_functions import (
12+
EqualityConstraint,
13+
GreaterThanOrEqualConstraint,
14+
LessThanOrEqualConstraint,
15+
)
16+
from pquant.core.keras.pruning_methods.metric_functions import (
17+
StructuredSparsityMetric,
18+
UnstructuredSparsityMetric,
19+
)
20+
21+
METRIC_REGISTRY = {
22+
"UnstructuredSparsity": UnstructuredSparsityMetric,
23+
"StructuredSparsity": StructuredSparsityMetric,
24+
}
25+
26+
CONSTRAINT_REGISTRY = {
27+
"Equality": EqualityConstraint,
28+
"LessThanOrEqual": LessThanOrEqualConstraint,
29+
"GreaterThanOrEqual": GreaterThanOrEqualConstraint,
30+
}
1231

1332
# -------------------------------------------------------------------
1433
# MDMM Layer
@@ -28,6 +47,10 @@ def __init__(self, config, layer_type, *args, **kwargs):
2847
self.constraint_layer = None
2948
self._is_finetuning = False
3049
self._is_pretraining = True
50+
# TEMP: cache last penalty so calculate_additional_loss() works in
51+
# custom training loops via get_model_losses(). Remove once the
52+
# add_loss()/model.fit path is the only supported path.
53+
self._last_penalty = None
3154

3255
def build(self, input_shape):
3356
pruning_parameters = self.config.pruning_parameters
@@ -94,8 +117,11 @@ def call(self, weight):
94117
self.mask.assign(ops.where(not_active, ops.convert_to_tensor(self.mask), hard_mask))
95118

96119
penalty = ops.sum(self.constraint_layer(weight))
97-
self.add_loss(ops.where(not_active, ops.zeros_like(penalty), penalty))
98-
120+
gated_penalty = ops.where(not_active, ops.zeros_like(penalty), penalty)
121+
self.add_loss(gated_penalty)
122+
# TEMP: cache for calculate_additional_loss() — remove with the
123+
# _last_penalty attribute once custom-loop callers move to model.losses.
124+
self._last_penalty = gated_penalty
99125
return ops.where(self.is_finetuning, weight * hard_mask, weight)
100126

101127
def get_hard_mask(self, weight=None):
@@ -109,7 +135,12 @@ def get_layer_sparsity(self, weight):
109135

110136
def calculate_additional_loss(self):
111137
# Loss is added via self.add_loss() in call() for model.fit.
112-
# For custom training loops, accumulate model.losses from the last forward pass instead.
138+
# TEMP: also return the cached penalty so custom training loops using
139+
# get_model_losses() see the constraint term. Remove this branch (and
140+
# the _last_penalty cache) once those callers switch to model.losses;
141+
# then this can revert to `return 0.0`.
142+
if self._last_penalty is not None:
143+
return self._last_penalty
113144
return 0.0
114145

115146
def pre_epoch_function(self, epoch, total_epochs):

0 commit comments

Comments
 (0)