Commit 714dbd2

Activation hooks redesign (reuse hooks component across both minitron and puzzletron) (#1022)
### What does this PR do?

**Type of change:** Redesign of existing feature

This PR introduces a shared activation hooks infrastructure for minitron and puzzletron. The framework provides a reusable component for collecting and analyzing activations during forward passes, used by both the minitron pruning and puzzletron algorithms.

Note: the minitron `megatron.py`/`mcore_minitron.py` `ImportanceEstimatorRegistry` code does not use this component yet; it will be refactored in a separate MR.

**Key changes:**

- Added the `modelopt/torch/prune/importance_hooks` module with the base hooks framework:
  - `base_hooks.py`: core hook infrastructure for registering and managing forward hooks
  - `base_hooks_analysis.py`: analysis utilities for processing collected activations
  - `megatron_hooks.py`: Megatron-specific hook implementations
  - `compare_module_outputs.py`: utilities for comparing module outputs
- Added unit tests in `tests/gpu/torch/prune/importance_hooks`:
  - `test_base_hooks.py`: tests for base hooks functionality
  - `test_base_hooks_analysis.py`: tests for activation analysis utilities
- Updated `test_mcore_gpt_minitron_pruning.py` to validate activation collection
- Updated test utilities for distributed testing support

### Before your PR is "*Ready for review*"

- Is this change backward compatible? ✅ Yes; this is a new module that doesn't affect existing functionality
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`? ✅ No
- Did you write any new necessary tests? ✅ Yes; added comprehensive tests for the activation hooks infrastructure
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)? ✅ N/A; this is infrastructure code that will be used by subsequent PRs

## Summary by CodeRabbit

* **New Features**
  * Robust JSON utilities for complex objects.
  * Flushed-print helper for synchronized output.
  * Activation-based importance estimation framework with multiple hook implementations, Megatron plugin support, and layer-output comparison tools.
  * Project-root test fixture for test suites.
* **Tests**
  * Many new end-to-end tests validating hooks, evaluation metrics, and multi-layer output comparisons.
  * Improved pruning tests with deterministic initialization and additional statistical assertions.
* **Chores**
  * Tests: mirror rank/size into environment for test init and disable external telemetry.

---

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
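To make the idea concrete, here is a minimal sketch of the forward-hook pattern the PR describes: hooks are registered on selected modules, accumulate per-module activation statistics during forward passes, and are later removed so an importance estimator can analyze the collected data. This is not the modelopt API; all names here (`ActivationCollector`, `stats`, `attach`, `detach`) are hypothetical illustrations of the general technique.

```python
import torch
import torch.nn as nn


class ActivationCollector:
    """Accumulates mean absolute activation per hooked module (illustrative only)."""

    def __init__(self):
        self.stats: dict[str, torch.Tensor] = {}
        self._handles = []

    def attach(self, model: nn.Module, module_types=(nn.Linear,)):
        # Register a forward hook on every module of the requested types.
        for name, module in model.named_modules():
            if isinstance(module, module_types):
                handle = module.register_forward_hook(self._make_hook(name))
                self._handles.append(handle)

    def _make_hook(self, name):
        def hook(module, inputs, output):
            # Mean |activation| over the batch, accumulated across forward passes.
            score = output.detach().abs().mean(dim=0)
            self.stats[name] = self.stats.get(name, 0) + score
        return hook

    def detach(self):
        # Remove all hooks so the model returns to its original behavior.
        for handle in self._handles:
            handle.remove()
        self._handles.clear()


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
collector = ActivationCollector()
collector.attach(model)
model(torch.randn(32, 8))
collector.detach()
print(sorted(collector.stats.keys()))  # prints ['0', '2']
```

The attach/detach split mirrors the lifecycle described above: hooks only observe activations while attached, and detaching guarantees no overhead remains on subsequent forward passes.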
1 parent (20a46e0) · commit 714dbd2

13 files changed

Lines changed: 1652 additions & 0 deletions

Lines changed: 23 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Forward hooks for estimating importance scores for pruning."""

from modelopt.torch.utils import import_plugin

from .base_hooks import *
from .base_hooks_analysis import *

with import_plugin("megatron_hooks"):
    from .plugins.megatron_hooks import *
```