Skip to content

Commit 2e3ac94

Browse files
authored
update sparsegpt and wandab (#127)
## Description <!-- Brief description of the changes --> ## Type of Change - [ ] Bug fix - [ ] New feature - [ ] Breaking change - [ ] Documentation ## Testing - [ ] Tests pass locally - [ ] New tests added (if applicable) ## Related Issues Closes #
2 parents bc61c3c + e4538af commit 2e3ac94

74 files changed

Lines changed: 6572 additions & 455 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ logs/
99
runs/
1010
outputs/
1111
results/
12-
experiments/
1312

1413

1514
# Backup files
@@ -161,7 +160,6 @@ dmypy.json
161160
/runs/
162161
/outputs/
163162
/results/
164-
/experiments/
165163

166164
# Temporary files
167165
*.tmp

configs/prune_llm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Configurations for generating results in the SCAR LLM pruning paper.
1515

1616
Run all experiments:
1717
```bash
18-
bash drafts/LLM_prune/paper/slurm/run_all_paper.sh
18+
bash slurm_jobs/prune_llm/run_all_paper.sh
1919
```
2020

2121
Run single model:

configs/prune_llm/llama2_7b_full.yaml

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ llm:
7777

7878
evaluate_perplexity: true
7979
evaluation_num_samples: 100
80+
# Use NVIDIA Minitron official few-shot settings for downstream tasks.
81+
use_nvidia_fewshot: true
82+
# Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity:
83+
# concatenate full test set and evaluate in contiguous 2048-token blocks (no padding).
84+
perplexity_protocol: "oats"
85+
wikitext_subset: "wikitext-2-raw-v1"
86+
perplexity_seq_len: 2048
8087

8188
evaluation_metrics:
8289
- "perplexity"
@@ -137,6 +144,20 @@ supernode:
137144
core_fraction: 0.01
138145
follower_fraction: 0.10
139146
halo_fraction: 0.10
147+
# Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass
148+
# that lands on the top-K hidden dimensions most written-to by supernodes.
149+
connectivity_topk: 256
150+
# Optional post-processing for Conn (defaults keep current behavior)
151+
connectivity_rank_normalize: false
152+
connectivity_power: 1.0
153+
# Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels
154+
# (used for paper mechanism plots; does NOT affect pruning decisions).
155+
non_halo_sample_size: 256
156+
non_halo_sample_seed: 0
157+
# Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma)
158+
protection_normalization: "rank_power"
159+
protection_rank_power: 8.0
160+
protection_floor: 0.2
140161
protect_core: true
141162
protect_core_metrics:
142163
- "scar_loss_proxy" # SCAR-LP
@@ -232,7 +253,7 @@ pruning:
232253
dependency_aware: true
233254

234255
sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
235-
selection_modes: ["low", "high"]
256+
selection_modes: ["low"]
236257

237258
algorithms:
238259
- "rayleigh_quotient"

configs/prune_llm/llama2_7b_unified.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ cascade_analysis:
139139
pruning:
140140
enabled: true
141141
ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
142-
selection_modes: ["low", "high"]
142+
selection_modes: ["low"]
143143
distribution: "uniform"
144144
min_per_layer: 0.0
145145
max_per_layer: 0.95

configs/prune_llm/llama3_8b_full.yaml

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ llm:
8787

8888
evaluate_perplexity: true
8989
evaluation_num_samples: 100
90+
# Use NVIDIA Minitron official few-shot settings for downstream tasks
91+
# (MMLU 5-shot, HellaSwag 10-shot, ARC 25-shot, WinoGrande 5-shot, etc.).
92+
use_nvidia_fewshot: true
93+
# Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity:
94+
# concatenate full test set and evaluate in contiguous 2048-token blocks (no padding).
95+
perplexity_protocol: "oats"
96+
wikitext_subset: "wikitext-2-raw-v1"
97+
perplexity_seq_len: 2048
9098

9199
evaluation_metrics:
92100
# Language modeling
@@ -99,6 +107,7 @@ llm:
99107
- "accuracy_hellaswag"
100108
- "accuracy_arc_easy"
101109
- "accuracy_arc_challenge"
110+
- "accuracy_openbookqa"
102111

103112
# Common Sense
104113
- "accuracy_winogrande"
@@ -174,6 +183,21 @@ supernode:
174183
core_fraction: 0.01
175184
follower_fraction: 0.10
176185
halo_fraction: 0.10
186+
# Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass
187+
# that lands on the top-K hidden dimensions most written-to by supernodes.
188+
# (Avoids the ~1/hidden_dim collapse of L1-normalized dot-product overlap for dense matrices.)
189+
connectivity_topk: 256
190+
# Optional post-processing for Conn (defaults keep current behavior)
191+
connectivity_rank_normalize: false
192+
connectivity_power: 1.0
193+
# Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels
194+
# (used for paper mechanism plots; does NOT affect pruning decisions).
195+
non_halo_sample_size: 256
196+
non_halo_sample_seed: 0
197+
# Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma)
198+
protection_normalization: "rank_power"
199+
protection_rank_power: 8.0
200+
protection_floor: 0.2
177201
protect_core: true
178202
# Apply hard supernode protection only for the listed pruning metrics.
179203
# If omitted, legacy behavior is to protect for *all* pruning metrics.
@@ -286,7 +310,9 @@ pruning:
286310
dependency_aware: true
287311

288312
sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
289-
selection_modes: ["low", "high"]
313+
# We only report (and run) the standard pruning direction: prune *low*-scoring channels.
314+
# The "high" mode (prune highest scores) is a pathological control and is excluded from paper runs.
315+
selection_modes: ["low"]
290316

291317
# ALL algorithms including SOTA baselines
292318
algorithms:
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# ============================================================================
2+
# LLAMA-3.1-8B RANDOM (CHANNEL) BASELINE
3+
# ============================================================================
4+
#
5+
# Purpose:
6+
# - Fill the missing "Random (channel)" baseline row in paper tables.
7+
# - Run ONLY one pruning strategy (random) at 50% sparsity.
8+
# - Keep evaluation protocol consistent with the main paper run (few-shot settings, ppl protocol).
9+
#
10+
# This is intentionally lightweight: we skip SCAR analyses/plots and only do:
11+
# - Baseline eval
12+
# - Random structured channel pruning @ 50%
13+
# - Post-prune eval
14+
# ============================================================================
15+
16+
experiment:
17+
name: "llama3_8b_paper_results_random"
18+
type: "llm_alignment"
19+
output_dir: "./results/paper/llama3_8b_random"
20+
seed: 42
21+
device: "cuda"
22+
save_activations: false
23+
num_networks: 1
24+
25+
model:
26+
name: "hf_causal_lm"
27+
model_id: "meta-llama/Llama-3.1-8B"
28+
dtype: "bfloat16"
29+
device_map: "auto"
30+
trust_remote_code: true
31+
32+
dataset:
33+
name: "wikitext"
34+
batch_size: 1
35+
num_workers: 0
36+
37+
llm:
38+
evaluate_perplexity: true
39+
evaluation_num_samples: 100
40+
use_nvidia_fewshot: true
41+
perplexity_protocol: "oats"
42+
wikitext_subset: "wikitext-2-raw-v1"
43+
perplexity_seq_len: 2048
44+
45+
evaluation_metrics:
46+
- "perplexity"
47+
- "accuracy_openbookqa"
48+
- "accuracy_mmlu"
49+
- "accuracy_hellaswag"
50+
- "accuracy_piqa"
51+
- "accuracy_boolq"
52+
- "accuracy_winogrande"
53+
- "accuracy_arc_easy"
54+
- "accuracy_arc_challenge"
55+
56+
# Disable heavy analyses for this baseline-only run
57+
analysis:
58+
generate_plots: false
59+
save_scores: false
60+
61+
do_scar_metrics: false
62+
do_directed_redundancy: false
63+
do_connectivity_pruning: false
64+
do_halo_analysis: false
65+
do_generalized_importance: false
66+
67+
pruning:
68+
enabled: true
69+
target: "ffn"
70+
structured: true
71+
dependency_aware: true
72+
distribution: "uniform"
73+
min_per_layer: 0.0
74+
max_per_layer: 0.95
75+
76+
# Single point needed for table_full_benchmarks_50
77+
sparsity_levels: [0.5]
78+
79+
# Random structured pruning: selection done by mode="random"
80+
selection_modes: ["random"]
81+
82+
# Only one strategy for this run; scores are generated in-code (deterministic).
83+
algorithms:
84+
- "random"
85+
86+
single_strategy: "random"
87+

configs/prune_llm/llama3_8b_unified.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# - All experiment-specific settings in `extra:` section
1111
# - Same pruning/evaluation/visualization structure
1212
#
13-
# Usage: python scripts/run_experiment.py --config configs/unified/llama3_8b_unified.yaml
13+
# Usage: python scripts/run_experiment.py --config configs/prune_llm/llama3_8b_unified.yaml
1414
# Estimated runtime: ~6-8 hours on 1x A100
1515
# =============================================================================
1616

@@ -156,7 +156,7 @@ cascade_analysis:
156156
pruning:
157157
enabled: true
158158
ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
159-
selection_modes: ["low", "high"]
159+
selection_modes: ["low"]
160160
distribution: "uniform"
161161
min_per_layer: 0.0
162162
max_per_layer: 0.95

configs/prune_llm/mistral_7b_full.yaml

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ llm:
7676

7777
evaluate_perplexity: true
7878
evaluation_num_samples: 100
79+
# Use NVIDIA Minitron official few-shot settings for downstream tasks.
80+
use_nvidia_fewshot: true
81+
# Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity:
82+
# concatenate full test set and evaluate in contiguous 2048-token blocks (no padding).
83+
perplexity_protocol: "oats"
84+
wikitext_subset: "wikitext-2-raw-v1"
85+
perplexity_seq_len: 2048
7986

8087
evaluation_metrics:
8188
- "perplexity"
@@ -136,6 +143,20 @@ supernode:
136143
core_fraction: 0.01
137144
follower_fraction: 0.10
138145
halo_fraction: 0.10
146+
# Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass
147+
# that lands on the top-K hidden dimensions most written-to by supernodes.
148+
connectivity_topk: 256
149+
# Optional post-processing for Conn (defaults keep current behavior)
150+
connectivity_rank_normalize: false
151+
connectivity_power: 1.0
152+
# Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels
153+
# (used for paper mechanism plots; does NOT affect pruning decisions).
154+
non_halo_sample_size: 256
155+
non_halo_sample_seed: 0
156+
# Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma)
157+
protection_normalization: "rank_power"
158+
protection_rank_power: 8.0
159+
protection_floor: 0.2
139160
protect_core: true
140161
protect_core_metrics:
141162
- "scar_loss_proxy" # SCAR-LP
@@ -231,7 +252,7 @@ pruning:
231252
dependency_aware: true
232253

233254
sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
234-
selection_modes: ["low", "high"]
255+
selection_modes: ["low"]
235256

236257
algorithms:
237258
- "rayleigh_quotient"

configs/prune_llm/mistral_7b_unified.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ cascade_analysis:
138138
pruning:
139139
enabled: true
140140
ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
141-
selection_modes: ["low", "high"]
141+
selection_modes: ["low"]
142142
distribution: "uniform"
143143
min_per_layer: 0.0
144144
max_per_layer: 0.95

configs/prune_llm/qwen2_7b_full.yaml

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ llm:
7777

7878
evaluate_perplexity: true
7979
evaluation_num_samples: 100
80+
# Use NVIDIA Minitron official few-shot settings for downstream tasks.
81+
use_nvidia_fewshot: true
82+
# Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity:
83+
# concatenate full test set and evaluate in contiguous 2048-token blocks (no padding).
84+
perplexity_protocol: "oats"
85+
wikitext_subset: "wikitext-2-raw-v1"
86+
perplexity_seq_len: 2048
8087

8188
evaluation_metrics:
8289
- "perplexity"
@@ -137,6 +144,20 @@ supernode:
137144
core_fraction: 0.01
138145
follower_fraction: 0.10
139146
halo_fraction: 0.10
147+
# Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass
148+
# that lands on the top-K hidden dimensions most written-to by supernodes.
149+
connectivity_topk: 256
150+
# Optional post-processing for Conn (defaults keep current behavior)
151+
connectivity_rank_normalize: false
152+
connectivity_power: 1.0
153+
# Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels
154+
# (used for paper mechanism plots; does NOT affect pruning decisions).
155+
non_halo_sample_size: 256
156+
non_halo_sample_seed: 0
157+
# Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma)
158+
protection_normalization: "rank_power"
159+
protection_rank_power: 8.0
160+
protection_floor: 0.2
140161
protect_core: true
141162
protect_core_metrics:
142163
- "scar_loss_proxy" # SCAR-LP
@@ -232,7 +253,7 @@ pruning:
232253
dependency_aware: true
233254

234255
sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
235-
selection_modes: ["low", "high"]
256+
selection_modes: ["low"]
236257

237258
algorithms:
238259
- "rayleigh_quotient"

0 commit comments

Comments
 (0)