|
| 1 | +# ============================================================================ |
| 2 | +# GEMMA-2B COMPREHENSIVE PRUNING COMPARISON |
| 3 | +# ============================================================================ |
| 4 | +# Reference: https://arxiv.org/abs/2403.08295 (Gemma) |
| 5 | +# |
| 6 | +# PURPOSE: Test SCAR on smaller efficient model (faster experiments) |
| 7 | +# |
| 8 | +# Gemma-2B specs: |
| 9 | +# - 18 layers, 2048 hidden dim, 16384 intermediate dim |
| 10 | +# - Uses GeGLU activation (similar to SwiGLU) |
| 11 | +# - Efficient model good for rapid experimentation |
| 12 | +# |
| 13 | +# NOTE: Gemma uses slightly different layer naming than Llama |
| 14 | +# ============================================================================ |
| 15 | + |
| 16 | +experiment: |
| 17 | + name: "gemma2b_pruning" |
| 18 | + type: "llm_alignment" |
| 19 | + seed: 42 |
| 20 | + device: "cuda" |
| 21 | + output_dir: "./results/gemma2b_pruning" |
| 22 | + num_networks: 1 |
| 23 | + |
| 24 | +model: |
| 25 | + name: "hf_causal_lm" |
| 26 | + model_id: "google/gemma-2b" |
| 27 | + dtype: "bfloat16" |
| 28 | + device_map: "auto" |
| 29 | + |
| 30 | + # Gemma uses same MLP structure as Llama (GeGLU, similar to SwiGLU) |
| 31 | + # Layer naming: model.layers.*.mlp.{up_proj, gate_proj, down_proj} |
| 32 | + tracked_layers: |
| 33 | + - "model.model.layers.*.mlp.up_proj" |
| 34 | + - "model.model.layers.*.mlp.gate_proj" |
| 35 | + - "model.model.layers.*.mlp.down_proj" |
| 36 | + |
| 37 | +dataset: |
| 38 | + name: "wikitext" |
| 39 | + batch_size: 1 |
| 40 | + num_workers: 0 |
| 41 | + |
| 42 | +# ============================================================================ |
| 43 | +# IMPORTANCE METRICS |
| 44 | +# ============================================================================ |
| 45 | +metrics: |
| 46 | + enabled: |
| 47 | + - "rayleigh_quotient" |
| 48 | + - "gaussian_mi_analytic" |
| 49 | + - "average_redundancy" |
| 50 | + - "activation_l2_norm" |
| 51 | + |
| 52 | + num_samples: 64 |
| 53 | + |
| 54 | + rayleigh_quotient: |
| 55 | + relative: true |
| 56 | + regularization: 1.0e-6 |
| 57 | + |
| 58 | +# ============================================================================ |
| 59 | +# LLM-SPECIFIC SETTINGS |
| 60 | +# ============================================================================ |
| 61 | +llm: |
| 62 | + scar_metrics: true |
| 63 | + scar_num_samples: 64 |
| 64 | + scar_max_length: 512 |
| 65 | + |
| 66 | + evaluate_perplexity: true |
| 67 | + evaluation_num_samples: 200 |
| 68 | + |
| 69 | + use_nvidia_fewshot: true |
| 70 | + use_chain_of_thought: true |
| 71 | + |
| 72 | + evaluation_metrics: |
| 73 | + - "perplexity" |
| 74 | + - "loss" |
| 75 | + - "accuracy_winogrande" |
| 76 | + - "accuracy_arc_challenge" |
| 77 | + - "accuracy_mmlu" |
| 78 | + - "accuracy_hellaswag" |
| 79 | + - "accuracy_arc_easy" |
| 80 | + - "accuracy_piqa" |
| 81 | + - "accuracy_boolq" |
| 82 | + |
| 83 | +# ============================================================================ |
| 84 | +# SUPERNODE CONFIGURATION |
| 85 | +# ============================================================================ |
| 86 | +supernode: |
| 87 | + enabled: true |
| 88 | + core_fraction: 0.01 |
| 89 | + follower_fraction: 0.10 |
| 90 | + score_metric: "activation_l2_norm" |
| 91 | + protect_core: true |
| 92 | + cross_layer_analysis: true |
| 93 | + compare_by_connection: true |
| 94 | + |
| 95 | + compute_metrics: |
| 96 | + - "activation" |
| 97 | + - "rayleigh_quotient" |
| 98 | + - "mutual_information" |
| 99 | + - "redundancy" |
| 100 | + |
| 101 | +# ============================================================================ |
| 102 | +# SUPERNODE ROBUSTNESS ANALYSIS |
| 103 | +# ============================================================================ |
| 104 | +supernode_robustness: |
| 105 | + enabled: true |
| 106 | + supernode_fraction: 0.01 |
| 107 | + num_bootstrap_samples: 5 |
| 108 | + batch_size: 32 |
| 109 | + max_samples: 128 |
| 110 | + |
| 111 | + metrics: |
| 112 | + - "scar_activation_power" |
| 113 | + - "scar_loss_proxy" |
| 114 | + - "rayleigh_quotient" |
| 115 | + - "activation_l2_norm" |
| 116 | + |
| 117 | + target_layers: null |
| 118 | + |
| 119 | +# ============================================================================ |
| 120 | +# SUPERNODE SUMMARY ANALYSIS |
| 121 | +# ============================================================================ |
| 122 | +# Generates summary plots: |
| 123 | +# 1. Halo vs Non-Halo metrics by layer (mean activation, RQ, MI, redundancy) |
| 124 | +# 2. Supernode outlier z-scores by layer (how much of an outlier) |
| 125 | +supernode_summary: |
| 126 | + enabled: true |
| 127 | + outlier_analysis: true |
| 128 | + |
| 129 | +# ============================================================================ |
| 130 | +# PRUNING CONFIGURATION |
| 131 | +# ============================================================================ |
| 132 | +pruning: |
| 133 | + enabled: true |
| 134 | + |
| 135 | + sparsity_levels: [0.25, 0.5, 0.75] |
| 136 | + |
| 137 | + selection_modes: ["low", "high"] |
| 138 | + |
| 139 | + distribution: "uniform" |
| 140 | + structured: true |
| 141 | + dependency_aware: true |
| 142 | + |
| 143 | + algorithms: |
| 144 | + # Alignment-based |
| 145 | + - "rayleigh_quotient" |
| 146 | + - "gaussian_mi_analytic" |
| 147 | + - "average_redundancy" |
| 148 | + # SCAR |
| 149 | + - "scar_loss_proxy" |
| 150 | + # Supernode-aware |
| 151 | + - "supernode_protection_score" |
| 152 | + - "supernode_connectivity_score" |
| 153 | + # Baselines |
| 154 | + - "activation_l2_norm" |
| 155 | + - "wanda" |
| 156 | + - "sparsegpt" |
| 157 | + |
| 158 | + single_strategy: null |
| 159 | + |
| 160 | + fine_tune: |
| 161 | + enabled: false |
| 162 | + |
| 163 | +# ============================================================================ |
| 164 | +# ADVANCED ANALYSIS FLAGS |
| 165 | +# ============================================================================ |
| 166 | +do_directed_redundancy: true |
| 167 | +do_connectivity_pruning: true |
| 168 | + |
| 169 | +# ============================================================================ |
| 170 | +# ANALYSIS & VISUALIZATION |
| 171 | +# ============================================================================ |
| 172 | +analysis: |
| 173 | + save_scores: true |
| 174 | + generate_plots: true |
| 175 | + |
| 176 | + plots: |
| 177 | + histograms: true |
| 178 | + scatter_plots: true |
| 179 | + pruning_curves: true |
| 180 | + redundancy_heatmaps: true |
| 181 | + |
| 182 | + scatter_pairs: |
| 183 | + - ["activation_l2_norm", "rayleigh_quotient"] |
| 184 | + - ["activation_l2_norm", "gaussian_mi_analytic"] |
| 185 | + - ["scar_activation_power", "scar_loss_proxy"] |
| 186 | + - ["rayleigh_quotient", "scar_loss_proxy"] |
| 187 | + - ["average_redundancy", "rayleigh_quotient"] |
| 188 | + |
| 189 | +visualization: |
| 190 | + format: "png" |
| 191 | + dpi: 300 |
0 commit comments