-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmdmm.py
More file actions
168 lines (141 loc) · 6.5 KB
/
Copy pathmdmm.py
File metadata and controls
168 lines (141 loc) · 6.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# @Author: Arghya Ranjan Das
# file: src/pquant/pruning_methods/mdmm.py
# modified by:
import inspect
import keras
from keras import ops
from pquant.core.constants import CONSTRAINT_REGISTRY, METRIC_REGISTRY
from pquant.pruning_methods.metric_functions import PACAPatternMetric
# -------------------------------------------------------------------
# MDMM Layer
# -------------------------------------------------------------------
@keras.saving.register_keras_serializable(package="PQuant")
class MDMM(keras.layers.Layer):
    """Pruning layer that attaches a constraint penalty to a weight tensor.

    ``build`` resolves a metric from ``METRIC_REGISTRY`` and a constraint from
    ``CONSTRAINT_REGISTRY`` (both selected through ``config.pruning_parameters``)
    and stores the resulting constraint layer. ``call`` receives a weight
    tensor, registers the constraint penalty through ``add_loss`` and returns
    the weight, multiplied by a hard mask once finetuning has started.

    The layer distinguishes three phases through two boolean flags:

    * pretraining (``is_pretraining`` is true): no penalty, mask not refreshed;
    * active (both flags false): penalty added, stored mask refreshed;
    * finetuning (``is_finetuning`` is true): no penalty, weight is masked.

    Each flag exists twice: as a plain Python attribute (``_is_*``) that is
    available before ``build`` and as a non-trainable boolean weight created
    in ``build`` and initialised from the Python attribute.

    Args:
        config: Either a ``PQConfig``-like object exposing
            ``pruning_parameters`` and ``get_dict()``, or a plain ``dict``
            that is converted with ``PQConfig.load_from_config``.
        layer_type: Stored unchanged and echoed back by ``get_config``.
        *args: Forwarded to ``keras.layers.Layer.__init__``.
        **kwargs: Forwarded to ``keras.layers.Layer.__init__``.
    """

    def __init__(self, config, layer_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if isinstance(config, dict):
            # Function-scope import kept from the original code
            # (presumably avoids a circular import -- TODO confirm).
            from pquant.core.hyperparameter_optimization import PQConfig

            config = PQConfig.load_from_config(config)
        self.config = config
        self.layer_type = layer_type
        self.constraint_layer = None  # created in build()
        self._is_finetuning = False
        self._is_pretraining = True

    def build(self, input_shape):
        """Create the metric, the constraint layer and the state weights.

        Args:
            input_shape: Shape of the weight tensor later passed to ``call``;
                also used as the shape of the stored ``mask`` weight.

        Raises:
            ValueError: If ``metric_type`` is not in ``METRIC_REGISTRY`` or the
                effective ``constraint_type`` is not in ``CONSTRAINT_REGISTRY``.
        """
        pruning_parameters = self.config.pruning_parameters
        metric_type = pruning_parameters.metric_type
        constraint_type = pruning_parameters.constraint_type
        target_value = pruning_parameters.target_value

        # Validate the lookup before introspecting it, so an unknown metric
        # fails with the intended message instead of relying on ``None``
        # happening to be introspectable.
        metric_cls = METRIC_REGISTRY.get(metric_type)
        if not metric_cls:
            raise ValueError(f"Unknown metric_type: {metric_type}")

        # Superset of keyword arguments understood by the known metrics; the
        # subset actually accepted by ``metric_cls`` is selected below.
        candidate_kwargs = {
            "epsilon": pruning_parameters.epsilon,
            "target_sparsity": pruning_parameters.target_sparsity,
            "l0_mode": pruning_parameters.l0_mode,
            "scale_mode": pruning_parameters.scale_mode,
            "rf": pruning_parameters.rf,
            # FPGAAwareSparsityMetric
            "precision": pruning_parameters.precision,
            "target_resource": pruning_parameters.target_resource,
            "bram_width": pruning_parameters.bram_width,
            # PACAPatternMetric
            "num_patterns_to_keep": pruning_parameters.num_patterns_to_keep,
            "beta": pruning_parameters.beta,
            "distance_metric": pruning_parameters.distance_metric,
        }
        accepted = inspect.signature(getattr(metric_cls, "__init__", metric_cls)).parameters
        # Unset (None) options are dropped so the metric's own defaults apply.
        metric_kwargs = {
            key: value
            for key, value in candidate_kwargs.items()
            if value is not None and key in accepted
        }
        metric_fn = metric_cls(**metric_kwargs)

        # PACA always pairs with EqualityConstraint at target 0 (preserves original design).
        if metric_type == "PACAPatternSparsity":
            constraint_type = "Equality"
            target_value = 0.0

        common_args = {
            "metric_fn": metric_fn,
            "target_value": target_value,
            "scale": pruning_parameters.scale,
            "damping": pruning_parameters.damping,
            "use_grad": pruning_parameters.use_grad,
            "lr": pruning_parameters.constraint_lr,
        }
        constraint_type_cls = CONSTRAINT_REGISTRY.get(constraint_type)
        if not constraint_type_cls:
            raise ValueError(f"Unknown constraint_type: {constraint_type}")
        self.constraint_layer = constraint_type_cls(**common_args)

        self.mask = self.add_weight(name="mask", shape=input_shape, initializer="ones", trainable=False)
        # The phase flags are initialised from the Python-side attributes so
        # that phase changes requested before build() are not lost.
        self.is_pretraining = self.add_weight(
            shape=(),
            initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype),
            name="is_pretraining",
            trainable=False,
            dtype="bool",
        )
        self.is_finetuning = self.add_weight(
            shape=(),
            initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype),
            name="is_finetuning",
            trainable=False,
            dtype="bool",
        )
        self.constraint_layer.build(input_shape)
        if self._is_finetuning:
            # pre_finetune_function() ran before the constraint layer existed;
            # apply the switch-off it could not perform at that time.
            self._turn_off_constraint()
        super().build(input_shape)

    def call(self, weight):
        """Register the constraint penalty and return the (possibly masked) weight.

        Args:
            weight: Weight tensor to constrain.

        Returns:
            ``weight * hard_mask`` when ``is_finetuning`` is true, otherwise
            ``weight`` unchanged.
        """
        epsilon = self.config.pruning_parameters.epsilon
        # PACA needs a projection mask in finetuning. By then, dominant_patterns are
        # guaranteed populated from prior active-phase constraint calls. In other phases
        # use the generic abs-threshold mask.
        is_paca = isinstance(self.constraint_layer.metric_fn, PACAPatternMetric)
        if is_paca and self._is_finetuning:
            hard_mask = self.constraint_layer.metric_fn.get_projection_mask(weight)
        else:
            hard_mask = ops.cast(ops.abs(weight) > epsilon, weight.dtype)

        # "Active" means neither pretraining nor finetuning.
        not_active = ops.logical_or(self.is_pretraining, self.is_finetuning)
        # The stored mask is refreshed only in the active phase; otherwise it keeps its value.
        self.mask.assign(ops.where(not_active, ops.convert_to_tensor(self.mask), hard_mask))
        # The constraint layer is always evaluated, but its penalty only
        # contributes to the loss in the active phase.
        penalty = ops.sum(self.constraint_layer(weight))
        self.add_loss(ops.where(not_active, ops.zeros_like(penalty), penalty))
        return ops.where(self.is_finetuning, weight * hard_mask, weight)

    def get_hard_mask(self, weight=None):
        """Return a 0/1 mask.

        Args:
            weight: Optional weight tensor. When omitted, the stored ``mask``
                weight is returned instead of computing a new mask.

        Returns:
            The stored mask as a tensor, or a tensor of ``weight``'s dtype that
            is 1 where ``abs(weight) > epsilon`` and 0 elsewhere.
        """
        if weight is None:
            return ops.convert_to_tensor(self.mask)
        epsilon = self.config.pruning_parameters.epsilon
        return ops.cast(ops.abs(weight) > epsilon, weight.dtype)

    def get_layer_sparsity(self, weight):
        """Return the fraction of entries of ``weight`` with magnitude above epsilon.

        NOTE(review): despite the name, this is the share of entries the hard
        mask keeps, not the share it removes -- confirm this is the convention
        callers expect before changing it.
        """
        mask = self.get_hard_mask(weight)
        return ops.sum(mask) / ops.cast(ops.size(weight), mask.dtype)

    def calculate_additional_loss(self):
        """Return ``0.0``; the penalty is reported through ``add_loss`` instead."""
        # Loss is added via self.add_loss() in call() for model.fit.
        # For custom training loops, accumulate model.losses from the last forward pass instead.
        return 0.0

    def pre_epoch_function(self, epoch, total_epochs):
        """No-op hook."""
        pass

    def pre_finetune_function(self):
        """Switch the layer to the finetuning phase and turn the constraint off.

        Safe to call before ``build``: the Python-side flag is always set, and
        the weight flag and the constraint layer are only touched if they
        already exist (``build`` applies the pending state otherwise).
        """
        self._is_finetuning = True
        if hasattr(self, "is_finetuning"):
            self.is_finetuning.assign(True)
        self._turn_off_constraint()

    def _turn_off_constraint(self):
        """Call ``turn_off`` on the constraint layer, if one has been created."""
        if self.constraint_layer is None:
            # Not built yet; nothing to switch off.
            return
        # Some constraint objects expose the real implementation as ``.module``.
        if hasattr(self.constraint_layer, "module"):
            self.constraint_layer.module.turn_off()
        else:
            self.constraint_layer.turn_off()

    def post_epoch_function(self, epoch, total_epochs):
        """No-op hook."""
        pass

    def post_pre_train_function(self):
        """Leave the pretraining phase (works both before and after ``build``)."""
        self._is_pretraining = False
        if hasattr(self, "is_pretraining"):
            self.is_pretraining.assign(False)

    def post_round_function(self):
        """No-op hook."""
        pass

    def get_config(self):
        """Return the Keras config extended with ``config`` (as a dict) and ``layer_type``."""
        config = super().get_config()
        config.update(
            {
                "config": self.config.get_dict(),
                "layer_type": self.layer_type,
            }
        )
        return config