Skip to content

Commit be42755

Browse files
authored
add halo analysis, test more models and benchmarks, add cross-layer
## Description <!-- Brief description of the changes --> ## Type of Change - [ ] Bug fix - [ ] New feature - [ ] Breaking change - [ ] Documentation ## Testing - [ ] Tests pass locally - [ ] New tests added (if applicable) ## Related Issues Closes #
2 parents ae5d20c + 90464eb commit be42755

42 files changed

Lines changed: 7966 additions & 84 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# ============================================================================
2+
# GEMMA-2B COMPREHENSIVE PRUNING COMPARISON
3+
# ============================================================================
4+
# Reference: https://arxiv.org/abs/2403.08295 (Gemma)
5+
#
6+
# PURPOSE: Test SCAR on smaller efficient model (faster experiments)
7+
#
8+
# Gemma-2B specs:
9+
# - 18 layers, 2048 hidden dim, 16384 intermediate dim
10+
# - Uses GeGLU activation (similar to SwiGLU)
11+
# - Efficient model good for rapid experimentation
12+
#
13+
# NOTE: Gemma uses slightly different layer naming than Llama
14+
# ============================================================================
15+
16+
experiment:
17+
name: "gemma2b_pruning"
18+
type: "llm_alignment"
19+
seed: 42
20+
device: "cuda"
21+
output_dir: "./results/gemma2b_pruning"
22+
num_networks: 1
23+
24+
model:
25+
name: "hf_causal_lm"
26+
model_id: "google/gemma-2b"
27+
dtype: "bfloat16"
28+
device_map: "auto"
29+
30+
# Gemma uses same MLP structure as Llama (GeGLU, similar to SwiGLU)
31+
# Layer naming: model.layers.*.mlp.{up_proj, gate_proj, down_proj}
32+
tracked_layers:
33+
- "model.model.layers.*.mlp.up_proj"
34+
- "model.model.layers.*.mlp.gate_proj"
35+
- "model.model.layers.*.mlp.down_proj"
36+
37+
dataset:
38+
name: "wikitext"
39+
batch_size: 1
40+
num_workers: 0
41+
42+
# ============================================================================
43+
# IMPORTANCE METRICS
44+
# ============================================================================
45+
metrics:
46+
enabled:
47+
- "rayleigh_quotient"
48+
- "gaussian_mi_analytic"
49+
- "average_redundancy"
50+
- "activation_l2_norm"
51+
52+
num_samples: 64
53+
54+
rayleigh_quotient:
55+
relative: true
56+
regularization: 1.0e-6
57+
58+
# ============================================================================
59+
# LLM-SPECIFIC SETTINGS
60+
# ============================================================================
61+
llm:
62+
scar_metrics: true
63+
scar_num_samples: 64
64+
scar_max_length: 512
65+
66+
evaluate_perplexity: true
67+
evaluation_num_samples: 200
68+
69+
use_nvidia_fewshot: true
70+
use_chain_of_thought: true
71+
72+
evaluation_metrics:
73+
- "perplexity"
74+
- "loss"
75+
- "accuracy_winogrande"
76+
- "accuracy_arc_challenge"
77+
- "accuracy_mmlu"
78+
- "accuracy_hellaswag"
79+
- "accuracy_arc_easy"
80+
- "accuracy_piqa"
81+
- "accuracy_boolq"
82+
83+
# ============================================================================
84+
# SUPERNODE CONFIGURATION
85+
# ============================================================================
86+
supernode:
87+
enabled: true
88+
core_fraction: 0.01
89+
follower_fraction: 0.10
90+
score_metric: "activation_l2_norm"
91+
protect_core: true
92+
cross_layer_analysis: true
93+
compare_by_connection: true
94+
95+
compute_metrics:
96+
- "activation"
97+
- "rayleigh_quotient"
98+
- "mutual_information"
99+
- "redundancy"
100+
101+
# ============================================================================
102+
# SUPERNODE ROBUSTNESS ANALYSIS
103+
# ============================================================================
104+
supernode_robustness:
105+
enabled: true
106+
supernode_fraction: 0.01
107+
num_bootstrap_samples: 5
108+
batch_size: 32
109+
max_samples: 128
110+
111+
metrics:
112+
- "scar_activation_power"
113+
- "scar_loss_proxy"
114+
- "rayleigh_quotient"
115+
- "activation_l2_norm"
116+
117+
target_layers: null
118+
119+
# ============================================================================
120+
# SUPERNODE SUMMARY ANALYSIS
121+
# ============================================================================
122+
# Generates summary plots:
123+
# 1. Halo vs Non-Halo metrics by layer (mean activation, RQ, MI, redundancy)
124+
# 2. Supernode outlier z-scores by layer (how much of an outlier)
125+
supernode_summary:
126+
enabled: true
127+
outlier_analysis: true
128+
129+
# ============================================================================
130+
# PRUNING CONFIGURATION
131+
# ============================================================================
132+
pruning:
133+
enabled: true
134+
135+
sparsity_levels: [0.25, 0.5, 0.75]
136+
137+
selection_modes: ["low", "high"]
138+
139+
distribution: "uniform"
140+
structured: true
141+
dependency_aware: true
142+
143+
algorithms:
144+
# Alignment-based
145+
- "rayleigh_quotient"
146+
- "gaussian_mi_analytic"
147+
- "average_redundancy"
148+
# SCAR
149+
- "scar_loss_proxy"
150+
# Supernode-aware
151+
- "supernode_protection_score"
152+
- "supernode_connectivity_score"
153+
# Baselines
154+
- "activation_l2_norm"
155+
- "wanda"
156+
- "sparsegpt"
157+
158+
single_strategy: null
159+
160+
fine_tune:
161+
enabled: false
162+
163+
# ============================================================================
164+
# ADVANCED ANALYSIS FLAGS
165+
# ============================================================================
166+
do_directed_redundancy: true
167+
do_connectivity_pruning: true
168+
169+
# ============================================================================
170+
# ANALYSIS & VISUALIZATION
171+
# ============================================================================
172+
analysis:
173+
save_scores: true
174+
generate_plots: true
175+
176+
plots:
177+
histograms: true
178+
scatter_plots: true
179+
pruning_curves: true
180+
redundancy_heatmaps: true
181+
182+
scatter_pairs:
183+
- ["activation_l2_norm", "rayleigh_quotient"]
184+
- ["activation_l2_norm", "gaussian_mi_analytic"]
185+
- ["scar_activation_power", "scar_loss_proxy"]
186+
- ["rayleigh_quotient", "scar_loss_proxy"]
187+
- ["average_redundancy", "rayleigh_quotient"]
188+
189+
visualization:
190+
format: "png"
191+
dpi: 300
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# ============================================================================
2+
# GPT-2 FAST PRUNING TEST
3+
# ============================================================================
4+
# PURPOSE: Quick validation of pruning pipeline on small model
5+
#
6+
# GPT-2 (124M) specs:
7+
# - 12 layers, 768 hidden dim, 3072 intermediate dim
8+
# - Uses standard MLP (fc1 -> GELU -> fc2), NOT SwiGLU
9+
# - Very fast for testing (~10-30 min on H100)
10+
#
11+
# NOTE: GPT-2 uses different MLP naming than Llama/Mistral:
12+
# - h.*.mlp.c_fc (up projection)
13+
# - h.*.mlp.c_proj (down projection)
14+
# - No gate_proj (standard MLP, not gated)
15+
#
16+
# Use this config to quickly test pipeline before running larger models
17+
# ============================================================================
18+
19+
experiment:
20+
name: "gpt2_fast_test"
21+
type: "llm_alignment"
22+
seed: 42
23+
device: "cuda"
24+
output_dir: "./results/gpt2_fast_test"
25+
num_networks: 1
26+
27+
model:
28+
name: "hf_causal_lm"
29+
model_id: "gpt2" # 124M params - very fast
30+
dtype: "float16" # GPT-2 works better with fp16 than bf16
31+
device_map: "auto"
32+
33+
# GPT-2 uses standard MLP (not SwiGLU)
34+
# c_fc is the up projection, c_proj is the down projection
35+
tracked_layers:
36+
- "model.transformer.h.*.mlp.c_fc"
37+
- "model.transformer.h.*.mlp.c_proj"
38+
39+
dataset:
40+
name: "wikitext"
41+
batch_size: 1
42+
num_workers: 0
43+
44+
# ============================================================================
45+
# IMPORTANCE METRICS (reduced for fast test)
46+
# ============================================================================
47+
metrics:
48+
enabled:
49+
- "rayleigh_quotient"
50+
- "activation_l2_norm"
51+
52+
num_samples: 16 # Fewer samples for speed
53+
54+
rayleigh_quotient:
55+
relative: true
56+
regularization: 1.0e-6
57+
58+
# ============================================================================
59+
# LLM-SPECIFIC SETTINGS (reduced for speed)
60+
# ============================================================================
61+
llm:
62+
scar_metrics: true
63+
scar_num_samples: 16 # Fewer samples
64+
scar_max_length: 256 # Shorter sequences
65+
66+
evaluate_perplexity: true
67+
evaluation_num_samples: 50 # Fewer eval samples
68+
69+
use_nvidia_fewshot: false # Skip few-shot for speed
70+
71+
# Minimal benchmarks for quick test
72+
evaluation_metrics:
73+
- "perplexity"
74+
- "loss"
75+
- "accuracy_hellaswag" # Just one benchmark
76+
77+
# ============================================================================
78+
# SUPERNODE CONFIGURATION
79+
# ============================================================================
80+
supernode:
81+
enabled: true
82+
core_fraction: 0.01
83+
follower_fraction: 0.10
84+
score_metric: "activation_l2_norm"
85+
protect_core: true
86+
cross_layer_analysis: false # Skip for speed
87+
compare_by_connection: false
88+
89+
compute_metrics:
90+
- "activation"
91+
92+
# ============================================================================
93+
# SUPERNODE ROBUSTNESS (disabled for fast test)
94+
# ============================================================================
95+
supernode_robustness:
96+
enabled: false
97+
98+
# ============================================================================
99+
# PRUNING CONFIGURATION (reduced for speed)
100+
# ============================================================================
101+
pruning:
102+
enabled: true
103+
104+
# Just two sparsity levels for quick test
105+
sparsity_levels: [0.25, 0.5]
106+
107+
selection_modes: ["low"] # Just one mode for speed
108+
109+
distribution: "uniform"
110+
structured: true
111+
dependency_aware: true
112+
113+
# Reduced algorithm set
114+
algorithms:
115+
- "rayleigh_quotient"
116+
- "activation_l2_norm"
117+
- "wanda"
118+
119+
single_strategy: null
120+
121+
fine_tune:
122+
enabled: false
123+
124+
# ============================================================================
125+
# ADVANCED ANALYSIS (disabled for speed)
126+
# ============================================================================
127+
do_directed_redundancy: false
128+
do_connectivity_pruning: false
129+
130+
# ============================================================================
131+
# ANALYSIS & VISUALIZATION
132+
# ============================================================================
133+
analysis:
134+
save_scores: true
135+
generate_plots: true
136+
137+
plots:
138+
histograms: true
139+
scatter_plots: false # Skip for speed
140+
pruning_curves: true
141+
redundancy_heatmaps: false
142+
143+
visualization:
144+
format: "png"
145+
dpi: 150 # Lower DPI for speed

0 commit comments

Comments
 (0)