-
Notifications
You must be signed in to change notification settings - Fork 423
Expand file tree
/
Copy pathconversion.py
More file actions
454 lines (354 loc) · 16.2 KB
/
conversion.py
File metadata and controls
454 lines (354 loc) · 16.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion and restoration utilities for sparse attention."""
import fnmatch
from collections.abc import Callable
from typing import Any
import torch.nn as nn
from modelopt import __version__ as mo_version
from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager
from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict
from modelopt.torch.utils import atomic_print, get_unwrapped_name
from .config import SparseAttentionConfig
from .plugins import register_custom_model_plugins_on_the_fly
from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry
from .utils import get_named_sparse_attention_modules, get_sparse_attention_modules
def _register_triton_backend_if_needed(model: nn.Module, config: SparseAttentionConfig) -> None:
"""Register the Triton attention backend and set attn_implementation if needed.
When the config uses ``backend="triton"``, this function:
1. Registers the Triton kernel with HF's ``ALL_ATTENTION_FUNCTIONS``.
2. Sets ``model.config._attn_implementation = "modelopt_triton"`` so the
model dispatches to the Triton kernel at forward time.
This is called automatically during ``mtsa.sparsify()`` so users never need
to manually call ``register_triton_attention()`` or set ``attn_implementation``.
"""
sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {}
needs_triton = any(
isinstance(v, dict) and v.get("backend") == "triton" for v in sparse_cfg.values()
)
if not needs_triton:
return
from .kernels import register_triton_attention
if register_triton_attention is not None:
register_triton_attention()
# Set attn_implementation on the model so HF dispatches to the Triton kernel.
# HF's ALL_ATTENTION_FUNCTIONS is checked at forward time, not construction time,
# so this works even after the model is already loaded.
model_config = getattr(model, "config", None)
if model_config is not None:
model_config._attn_implementation = "modelopt_triton"
def is_attn_sparsified(model: nn.Module) -> bool:
"""Check if a model has sparse attention applied.
Similar to quantization's is_quantized for API consistency.
Args:
model: Model to check
Returns:
True if model contains any SparseAttentionModule instances
"""
return any(isinstance(module, SparseAttentionModule) for module in model.modules())
def convert_to_sparse_attention_model(
model: ModelLikeModule, config: SparseAttentionConfig
) -> ConvertReturnType:
"""Convert model to use sparse attention.
Args:
model: Model to convert
config: Sparse attention configuration
Returns:
Tuple of (converted_model, metadata)
"""
# Initialize the true module if necessary
model = model.init_modellike() if isinstance(model, ModelLikeModule) else model
# Register Triton attention backend and set attn_implementation if needed
_register_triton_backend_if_needed(model, config)
# Apply custom model plugins
register_custom_model_plugins_on_the_fly(model)
# Replace attention modules with sparse versions
replace_sparse_attention_modules(model, version=ModeloptStateManager(model).state_version)
# Apply configuration to sparse attention modules
sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {}
set_sparse_attention_by_cfg(model, sparse_cfg)
# Create metadata
metadata = {}
update_sparse_attention_metadata(model, config, metadata)
return model, metadata
def replace_sparse_attention_modules(model: nn.Module, version=None):
"""Replace regular attention modules with sparse attention modules.
Recursively replace all attention modules in the model with their sparse attention counterparts.
Args:
model: Model to process
version: State version for tracking (optional)
"""
# Recursively replace modules
_replace_sparse_attention_modules(model, version=version)
# Count and report replaced modules
replaced_count = len(get_sparse_attention_modules(model))
if replaced_count > 0:
print(f"Inserted {replaced_count} sparse attention modules")
def _replace_sparse_attention_modules(model: nn.Module, version=None):
"""Helper function for replace_sparse_attention_modules."""
for name, child in model.named_children():
if type(child) in SparseAttentionRegistry:
# REPLACE on the parent (model), not on child
sparse_module = SparseAttentionRegistry.convert(child)
setattr(model, name, sparse_module)
# Now recurse into whichever module is now at `model.name`
_replace_sparse_attention_modules(getattr(model, name), version=version)
def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict):
"""Apply sparse attention configuration to model.
Similar to quantization's set_quantizer_by_cfg.
Args:
model: Model with sparse attention modules
sparse_cfg: Sparse configuration dictionary mapping patterns to attributes
"""
sparse_cfg = sparse_cfg.copy()
# Apply default first if exists
if "default" in sparse_cfg:
set_sparse_attention_attribute(model, "*", sparse_cfg["default"])
sparse_cfg.pop("default")
# Apply pattern-specific configs
for pattern, cfg in sparse_cfg.items():
set_sparse_attention_attribute(model, pattern, cfg)
def set_sparse_attention_attribute(
model: nn.Module,
wildcard_or_filter: str | Callable,
attribute_cfg: dict[str, Any],
):
"""Set sparse attention attributes for modules matching pattern.
Similar to quantization's set_quantizer_attribute.
Args:
model: Model to configure
wildcard_or_filter: Pattern to match module names
attribute_cfg: Attributes to apply (must include 'method')
"""
# Filter out model-level configs that shouldn't be passed to modules
module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"}
for name, module in get_named_sparse_attention_modules(model):
# Check pattern match
matched = False
if isinstance(wildcard_or_filter, str):
matched = fnmatch.fnmatch(name, wildcard_or_filter)
elif callable(wildcard_or_filter):
matched = wildcard_or_filter(name)
else:
raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter)}")
if matched:
# Apply config using the same method as TensorQuantizer
module.set_from_attribute_config(module_cfg)
def restore_sparse_attention_model(
model: ModelLikeModule, config: SparseAttentionConfig, metadata: MetadataDict
) -> nn.Module:
"""Restore sparse attention model from saved state.
Args:
model: Model to restore
config: Sparse attention configuration
metadata: Saved metadata
Returns:
Restored model
"""
# Convert to sparse attention model
model, _ = convert_to_sparse_attention_model(model, config)
# Restore sparse attention state from metadata
if "sparse_attention_state" in metadata:
restore_sparse_attention_state(model, metadata["sparse_attention_state"])
return model
def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]):
"""Restore sparse attention state from state dict.
Args:
model: Model with sparse attention modules
state_dict: Saved state dictionary
"""
for name, module in get_named_sparse_attention_modules(model):
module_name = get_unwrapped_name(name, model)
if module_name in state_dict:
module_state = state_dict[module_name]
# Restore method and config
if "method" in module_state:
module._method = module_state["method"]
if "method_config" in module_state:
# Restore config attributes
for key, val in module_state["method_config"].items():
setattr(module, f"_{key}", val)
# Re-setup with restored config
module._setup()
def update_sparse_attention_metadata(
model: nn.Module, config: SparseAttentionConfig, metadata: MetadataDict
) -> None:
"""Update metadata with sparse attention state.
Args:
model: Model with sparse attention
config: Configuration used
metadata: Metadata dict to update
"""
sparse_state = {}
for name, module in get_named_sparse_attention_modules(model):
module_name = get_unwrapped_name(name, model)
# Save the method configuration that was used
# _method_config already contains the validated config dict
module_state = {
"method": module._sparse_method_instance.name,
"method_config": module._method_config.copy(),
}
sparse_state[module_name] = module_state
metadata["sparse_attention_state"] = sparse_state
metadata["sparse_attention_config"] = (
config.model_dump() if hasattr(config, "model_dump") else vars(config)
)
def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None:
"""Extract sparse attention config for export to config.json.
Extracts the calibration parameters (a, b) for the exponential threshold model
from the first sparse attention module that has calibrated thresholds.
The exported config allows computing threshold at runtime:
scale_factor = a * exp(b * target_sparsity)
threshold = scale_factor / seqlen
Args:
model: Model with sparse attention applied
Returns:
Dictionary with sparse attention config for HuggingFace config.json export.
Returns None if no calibrated sparse attention modules found.
Example output::
{
"config_groups": {
"group_0": {"sparse_algo": "softmax_skip", "targets": ["LlamaAttention"]}
},
"threshold_scale_factor": {
"formula": "a * exp(b * target_sparsity)",
"prefill": {"a": 7.93, "b": 8.61},
"decode": {"a": 0.12, "b": 9.85},
},
"producer": {"name": "modelopt", "version": "0.37.0"},
}
"""
# Collect sparse attention module info
calibration_params = None
target_classes: set[str] = set()
for module in get_sparse_attention_modules(model):
# Get the original wrapped module's class name
if hasattr(module, "get_original_cls_by_level"):
original_cls = module.get_original_cls_by_level(level=0)
if original_cls is not None:
target_classes.add(original_cls.__name__)
# Get calibration params from first module that has them
if calibration_params is None:
calibration_params = getattr(module._sparse_method_instance, "calibration_params", None)
# Return None if no calibration params found
if calibration_params is None:
return None
# Build threshold_scale_factor with model parameters
threshold_scale_factor: dict[str, Any] = {
"formula": "a * exp(b * target_sparsity)",
}
for phase in ["prefill", "decode"]:
if phase in calibration_params:
threshold_scale_factor[phase] = {
"a": calibration_params[phase]["a"],
"b": calibration_params[phase]["b"],
}
# Build the export config
export_config: dict[str, Any] = {
"config_groups": {
"group_0": {
"sparse_algo": "softmax_skip",
"targets": sorted(target_classes) if target_classes else ["Attention"],
}
},
"threshold_scale_factor": threshold_scale_factor,
"producer": {
"name": "modelopt",
"version": mo_version,
},
}
return export_config
def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable):
"""Disable sparse attention for matching modules.
Similar to mtq.disable_quantizer for API consistency.
Args:
model: Model with sparse attention applied
wildcard_or_filter_func: Wildcard string or filter function to match module names.
For example: "*lm_head*", "*layer_0*", etc.
Example:
>>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn
>>> model = sparse_attn.sparsify(model, config)
>>> # Disable sparse attention for lm_head
>>> sparse_attn.disable_sparse_attention(model, "*lm_head*")
"""
for name, module in get_named_sparse_attention_modules(model):
matched = False
if isinstance(wildcard_or_filter_func, str):
matched = fnmatch.fnmatch(name, wildcard_or_filter_func)
elif callable(wildcard_or_filter_func):
matched = wildcard_or_filter_func(name)
if matched:
module.disable()
def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable):
"""Enable sparse attention for matching modules.
Similar to mtq.enable_quantizer for API consistency.
Args:
model: Model with sparse attention applied
wildcard_or_filter_func: Wildcard string or filter function to match module names.
For example: "*attention*", "*attn*", etc.
Example:
>>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn
>>> model = sparse_attn.sparsify(model, config)
>>> # Re-enable sparse attention for all attention modules
>>> sparse_attn.enable_sparse_attention(model, "*attention*")
"""
for name, module in get_named_sparse_attention_modules(model):
matched = False
if isinstance(wildcard_or_filter_func, str):
matched = fnmatch.fnmatch(name, wildcard_or_filter_func)
elif callable(wildcard_or_filter_func):
matched = wildcard_or_filter_func(name)
if matched:
module.enable()
def _format_threshold(info: dict) -> str:
"""Format threshold info for display."""
t = info.get("type")
if t == "dynamic_calibrated":
# Exponential model: threshold = a * exp(b * sparsity) / seqlen
params = info.get("calibration_params", {})
target = info.get("target_sparse_ratio", {})
parts = []
for phase in ["prefill", "decode"]:
if phase in params:
a, b = params[phase]["a"], params[phase]["b"]
s = target.get(phase, 0.5)
parts.append(f"{phase}: a={a:.4f}, b={b:.2f}, target={s:.0%}")
return f"calibrated({', '.join(parts)})"
if t == "static":
v = info.get("value")
if isinstance(v, dict):
return f"threshold={v}"
return f"threshold={v:.2e}" if isinstance(v, float) else f"threshold={v}"
return "threshold=N/A"
@atomic_print
def print_sparse_attention_summary(model: nn.Module):
"""Print summary of sparse attention modules in the model.
Args:
model: Model with sparse attention applied
"""
sparse_modules = get_named_sparse_attention_modules(model)
if not sparse_modules:
print("No sparse attention modules found")
return
enabled = sum(1 for _, m in sparse_modules if m.is_enabled)
print(f"Sparse attention: {enabled}/{len(sparse_modules)} modules enabled")
# Group by (method, threshold)
groups: dict[tuple[str, str], int] = {}
for _, module in sparse_modules:
method = getattr(module, "_method", "unknown")
threshold = _format_threshold(module.get_threshold_info())
groups[(method, threshold)] = groups.get((method, threshold), 0) + 1
for (method, threshold), count in sorted(groups.items()):
print(f" {method}: {count} layers, {threshold}")