This project uses Hydra for structured, composable configuration.
The configuration is split into groups, defined in conf/:
config.yaml: Main configuration and defaults.
model/: Model architecture settings.
data/: Data generation settings.
sampler/: Sampler settings (HMC, SGLD, MCLMC parameters).
You can create new presets by adding YAML files to these directories.
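The commands in this document set values with dotted key=value overrides on the command line. As a rough illustration of how such overrides map onto a nested configuration, here is a minimal pure-Python sketch. It is not Hydra's implementation: `apply_overrides` and `parse_value` are hypothetical helpers, the default values in the usage note are made up, and the value parsing is far simpler than Hydra's override grammar. The sketch also silently creates missing keys, whereas Hydra normally expects a `+` prefix for keys that are absent from the config.

```python
import copy


def parse_value(raw):
    """Interpret a raw override string as int, then float, else keep it as str."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(config, overrides):
    """Return a copy of `config` with dotted `key=value` overrides applied.

    Hypothetical helper for illustration only; Hydra does this internally
    with a richer grammar and stricter validation.
    """
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            # Walk (or create) the nested dictionaries along the dotted path.
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw)
    return result
```

For example, `apply_overrides({"sampler": {"chains": 4}}, ["sampler.chains=8"])` returns a new dictionary with `sampler.chains` set to 8 and leaves the input untouched.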
Controls the type of target function for LLC estimation.
target=mlp: Multi-Layer Perceptron (default) using Haiku
target=quadratic: Simple quadratic target L_n(θ) = 0.5||θ||² for testing
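The quadratic target is simple enough to write out by hand. The following is a minimal sketch in plain Python, not the project's code: the function names are hypothetical, and it assumes θ is a flat list of floats. A non-degenerate quadratic in d dimensions has learning coefficient d/2, which is presumably why this target is convenient for checking an LLC estimator.

```python
def quadratic_loss(theta):
    """L_n(theta) = 0.5 * ||theta||^2 for a flat list of floats."""
    return 0.5 * sum(t * t for t in theta)


def quadratic_grad(theta):
    """Gradient of the quadratic target, which is simply theta itself."""
    return list(theta)
```

With four parameters, for instance, an accurate estimator should report an LLC close to 2.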
Example:
uv run python train.py target=quadratic quad_dim=4

Controls the architecture of the student MLP (when target=mlp).
model.depth: Number of hidden layers.
model.target_params: Approximate total number of parameters (widths are inferred).
model.activation: Activation function (e.g., relu, tanh, gelu, identity for deep-linear).
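One way widths could be inferred from model.target_params is sketched below. This is an assumption, not the project's actual rule: it supposes every hidden layer has the same width, counts both weights and biases, and takes hypothetical `in_dim` and `out_dim` arguments that do not appear in the configuration described here.

```python
def mlp_param_count(width, depth, in_dim, out_dim):
    """Weights plus biases of an MLP with `depth` hidden layers of equal width."""
    first = in_dim * width + width
    hidden = (depth - 1) * (width * width + width)
    last = width * out_dim + out_dim
    return first + hidden + last


def infer_width(target_params, depth, in_dim, out_dim):
    """Uniform hidden width whose parameter count is closest to `target_params`."""
    width = 1
    # Grow the width while the next size still fits under the target.
    while mlp_param_count(width + 1, depth, in_dim, out_dim) <= target_params:
        width += 1
    below = mlp_param_count(width, depth, in_dim, out_dim)
    above = mlp_param_count(width + 1, depth, in_dim, out_dim)
    # Pick whichever neighbour lands nearer the target.
    return width if target_params - below <= above - target_params else width + 1
```

Under these assumptions, depth 4 with 5000 target parameters, 4 inputs, and 1 output gives a width of 39, for 4915 parameters in total.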
Example:
uv run python train.py model.depth=4 model.target_params=5000 model.activation=identity

Controls the synthetic data generation process using a teacher-student setup.
data.n_data: Number of data points.
data.x_dist: Input distribution:
  gauss_iso: Isotropic Gaussian
  gauss_aniso: Anisotropic Gaussian
  mixture: Mixture of Gaussians
  lowdim_manifold: Low-dimensional manifold
  heavy_tail: Heavy-tailed distribution
data.noise_model: Noise model:
  gauss: Gaussian noise
  hetero: Heteroscedastic noise
  student_t: Student-t noise
  outliers: Noise with outliers
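To make the difference between the noise models concrete, here is a minimal standard-library sketch of what each option could look like. Everything specific in it is an assumption for illustration: the `sample_noise` helper, the base scale, the degrees of freedom, the outlier probability and inflation factor, and the way the heteroscedastic variance depends on the input norm. The project's actual definitions may differ.

```python
import math
import random


def sample_noise(noise_model, x_norm, rng, scale=0.1):
    """Draw one additive noise value (all constants are illustrative assumptions)."""
    if noise_model == "gauss":
        return rng.gauss(0.0, scale)
    if noise_model == "hetero":
        # Heteroscedastic: the standard deviation grows with the input norm.
        return rng.gauss(0.0, scale * (1.0 + x_norm))
    if noise_model == "student_t":
        # Student-t draw built as normal / sqrt(chi-square / df), with df assumed to be 4.
        df = 4.0
        z = rng.gauss(0.0, 1.0)
        chi2 = rng.gammavariate(df / 2.0, 2.0)
        return scale * z / math.sqrt(chi2 / df)
    if noise_model == "outliers":
        # With an assumed 5% probability, inflate the standard deviation tenfold.
        sd = scale * 10.0 if rng.random() < 0.05 else scale
        return rng.gauss(0.0, sd)
    raise ValueError(f"unknown noise model: {noise_model}")
```

A teacher-student label would then be the teacher's output plus `sample_noise(...)` for the chosen model.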
Example:
uv run python train.py data.n_data=10000 data.noise_model=student_t data.x_dist=mixture

Controls the ERM (Empirical Risk Minimization) training process.
training.optimizer: Optimizer (default: adam)
training.learning_rate: Learning rate for optimization
training.erm_steps: Number of training steps
Controls the posterior configuration for LLC estimation.
posterior.loss: Loss function (e.g., mse)
posterior.beta_mode: Temperature schedule (e.g., 1_over_log_n)
posterior.beta0: Base temperature parameter
posterior.gamma: Localization strength around w*
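The sketch below shows how these settings could fit together, following the commonly used localized, tempered posterior for LLC estimation. It rests on assumptions rather than on the project's code: that 1_over_log_n means beta = beta0 / log(n), that the log-density is -n·beta·L_n(w) - (gamma/2)·||w - w*||², and that the estimate is n·beta·(E[L_n(w)] - L_n(w*)). All function names are hypothetical.

```python
import math


def beta_one_over_log_n(n_data, beta0=1.0):
    """Assumed form of the 1_over_log_n schedule: beta = beta0 / log(n)."""
    return beta0 / math.log(n_data)


def localized_log_density(loss, sq_dist, n_data, beta, gamma):
    """Unnormalized log-density of the localized, tempered posterior.

    `loss` is L_n(w) and `sq_dist` is ||w - w*||^2 for the same w.
    """
    return -n_data * beta * loss - 0.5 * gamma * sq_dist


def llc_estimate(sampled_losses, loss_at_w_star, n_data, beta):
    """LLC estimate n * beta * (mean sampled loss - loss at w*)."""
    mean_loss = sum(sampled_losses) / len(sampled_losses)
    return n_data * beta * (mean_loss - loss_at_w_star)
```

A larger gamma keeps samples closer to w*, and beta sets how strongly the loss term is weighted against that localization term.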
Controls the parameters for the MCMC samplers.
sampler.chains: Number of parallel chains.
sampler.hmc.draws: Number of HMC draws.
sampler.hmc.warmup: Number of HMC warmup steps.
sampler.sgld.steps: Total number of SGLD steps.
sampler.sgld.step_size: SGLD step size.
sampler.sgld.batch_size: Minibatch size for SGLD.
sampler.mclmc.draws: Number of MCLMC draws.
Example:
uv run python train.py sampler.chains=8 sampler.hmc.draws=5000 sampler.sgld.step_size=1e-6
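To show what the three SGLD settings control, here is a minimal sketch of a textbook SGLD chain in plain Python. It is not the project's sampler: `sgld_chain` and `gaussian_mean_grad` are hypothetical names, the toy potential is a sum of squared deviations from the data, and the update uses the standard form theta ← theta - (step_size / 2) · grad + N(0, step_size), with the minibatch gradient rescaled by n / batch_size.

```python
import math
import random


def sgld_chain(theta0, data, grad_fn, steps, step_size, batch_size, seed=0):
    """Run one SGLD chain and return the list of visited parameter vectors."""
    rng = random.Random(seed)
    n = len(data)
    theta = list(theta0)
    noise_sd = math.sqrt(step_size)
    trace = []
    for _ in range(steps):
        batch = rng.sample(data, batch_size)
        # Rescale the minibatch gradient to estimate the full-data gradient.
        grad = [n / batch_size * g for g in grad_fn(theta, batch)]
        theta = [
            t - 0.5 * step_size * g + rng.gauss(0.0, noise_sd)
            for t, g in zip(theta, grad)
        ]
        trace.append(list(theta))
    return trace


def gaussian_mean_grad(theta, batch):
    """Gradient of sum_i 0.5 * (theta - x_i)^2 over the batch, for 1-D theta."""
    return [sum(theta[0] - x for x in batch)]
```

A smaller step size reduces discretization error but slows mixing, and a larger batch size reduces gradient noise at a higher cost per step.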