Skip to content

Commit 2aa7c22

Browse files
committed
Move files around
1 parent c783df4 commit 2aa7c22

29 files changed

Lines changed: 367 additions & 294 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Data attribution methods estimate the effect on a behavior of interest of removi
55

66
## Core features
77

8-
Per-token and per-sequence attribution is available everywhere. On-disk gradient stores and on-the-fly queries are supported. Almost every feature is available through both the CLI and a programmatic interface, which use a shared set of configuration dataclasses. Configuration dataclasses are always serialized to disk so commands can be reproduced in one line. To understand every available configuration option, [check out the documentation](https://bergson.readthedocs.io/en/latest/api.html#bergson.IndexConfig).
8+
Per-token and per-sequence attribution is available everywhere. On-disk gradient stores and on-the-fly queries are supported. Almost every feature is available through both the CLI and a programmatic interface, which use a shared set of configuration dataclasses. Configuration dataclasses are always serialized to disk so commands can be reproduced in one line. To understand every available configuration option, [check out the documentation](https://bergson.readthedocs.io/en/latest/api.html#bergson.IndexConfig).
99

1010
Bergson uses FSDP2 or SimpleFSDP, BitsAndBytes, and low-level performance optimizations to support large models, datasets, and clusters. Bergson integrates with HuggingFace Transformers and Datasets, and also supports on-disk datasets in a variety of formats.
1111

bergson/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .collector.collector import CollectorComputer
88
from .collector.gradient_collectors import GradientCollector
99
from .collector.in_memory_collector import InMemoryCollector
10-
from .config import (
10+
from .config.config import (
1111
AttentionConfig,
1212
DataConfig,
1313
IndexConfig,

bergson/__main__.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from simple_parsing import ArgumentParser, ConflictResolution
77

8+
from bergson.config.config_io import parse_steps, read_config, save_pipeline_config
9+
810
from .cli.commands import (
911
ApproxUnrolling,
1012
Build,
@@ -20,7 +22,6 @@
2022
Train,
2123
Validate,
2224
)
23-
from .cli.config_runner import run_config
2425

2526

2627
@dataclass
@@ -48,6 +49,31 @@ def execute(self):
4849
self.command.execute()
4950

5051

52+
def run_config(config_path: str, command_registry: dict[str, type]) -> None:
53+
"""Execute each step of a bergson config YAML in order.
54+
55+
A fully resolved version of any multi-step config (including default values)
56+
is written to ``run_path`` (auto-named under ``runs/`` if not given).
57+
Each step also writes its own component ``config.yaml`` into its run directory.
58+
"""
59+
config = read_config(config_path)
60+
61+
steps = parse_steps(config["steps"], command_registry)
62+
63+
multi = len(steps) > 1
64+
65+
if multi:
66+
# Optional top-level run path for a multi-step pipeline
67+
run_path = config.get("run_path")
68+
69+
save_pipeline_config(steps, run_path)
70+
71+
for i, (cmd_name, cmd) in enumerate(steps, start=1):
72+
if multi:
73+
print(f"\n[pipeline] step {i}/{len(steps)}: {cmd_name}")
74+
cmd.execute()
75+
76+
5177
def main():
5278
"""Parse CLI arguments and dispatch to the selected subcommand.
5379
@@ -66,9 +92,8 @@ def main():
6692
command_classes = get_args(Main.__dataclass_fields__["command"].type)
6793
command_registry = {cls.__name__.lower(): cls for cls in command_classes}
6894

69-
# Config-file mode: the YAML file is self-describing, so no command verb is
70-
# needed. Accept it as the sole argument; leading command words (e.g.
71-
# `bergson build run/config.yaml`) are ignored.
95+
# Config-file mode: accept a YAML file as the sole argument.
96+
# Leading command words (e.g. `bergson build run/config.yaml`) are ignored.
7297
config_path: str | None = None
7398
if len(args) == 1 and os.path.isfile(args[0]):
7499
config_path = args[0]
@@ -79,7 +104,7 @@ def main():
79104
run_config(config_path, command_registry)
80105
return
81106

82-
# CLI-flag mode: standard argparse-style flag parsing.
107+
# CLI-flag mode: argparse-style flag parsing.
83108
parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT)
84109
parser.add_arguments(Main, dest="prog")
85110
prog: Main = parser.parse_args().prog

bergson/approx_unrolling/approx_unrolling_math.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
import torch
1010
from torch import Tensor
1111

12-
from bergson.config import (
12+
from bergson.cli.commands import Score
13+
from bergson.config.config import (
1314
ApproxUnrollingConfig,
1415
DistributedConfig,
1516
IndexConfig,
1617
PreprocessConfig,
1718
ScoreConfig,
1819
)
20+
from bergson.config.config_io import save_run_config
1921
from bergson.data import load_scores
20-
from bergson.cli.commands import Score
21-
from bergson.cli.config_io import save_run_config
2222
from bergson.distributed import init_dist, launch_distributed_run
2323
from bergson.hessians.apply_hessian import EkfacApplicator, EkfacConfig
2424
from bergson.score.score import score_dataset

bergson/approx_unrolling/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@
3030

3131
from ..build import build
3232
from ..cli.commands import Build
33-
from ..cli.config_io import save_run_config
3433
from ..config import (
3534
ApproxUnrollingConfig,
3635
HessianConfig,
3736
IndexConfig,
3837
PreprocessConfig,
3938
)
39+
from ..config.config_io import save_run_config
4040
from ..utils.logger import get_logger
4141
from .approx_unrolling_math import (
4242
compute_lr_times_steps_per_segment,

bergson/approx_unrolling/precompute_checkpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
CollectorComputer,
99
fwd_bwd_hessian_factory,
1010
)
11-
from bergson.config import (
11+
from bergson.config.config import (
1212
ApproxUnrollingConfig,
1313
HessianConfig,
1414
IndexConfig,

bergson/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tqdm.auto import tqdm
99

1010
from bergson.collection import collect_gradients
11-
from bergson.config import HessianConfig, IndexConfig, PreprocessConfig
11+
from bergson.config.config import HessianConfig, IndexConfig, PreprocessConfig
1212
from bergson.data import allocate_batches
1313
from bergson.distributed import (
1414
cap_world_size_to_dataset,

bergson/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.distributed as dist
66
from datasets import Dataset
77

8-
from .config import PreprocessConfig
8+
from .config.config import PreprocessConfig
99
from .data import compute_num_token_grads, create_index, create_token_index
1010
from .process_grads import (
1111
get_trackstar_hessian,

bergson/cli/commands.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
"""CLI command definitions.
22
33
Each command is a thin dataclass that validates, persists its ``config.yaml``,
4-
and dispatches. They live here rather than in ``__main__`` so pipelines can
5-
lazily import lower-level commands for step serialization without importing
6-
the corresponding CLI entrypoint.
4+
and dispatches. They live here rather than in ``__main__`` so pipelines can
5+
import lower-level commands for config serialization without importing the
6+
corresponding CLI entrypoint.
77
"""
88

99
from dataclasses import dataclass
1010

1111
from simple_parsing import Serializable
1212

1313
from ..build import build
14-
from ..config import (
14+
from ..config.config import (
1515
ApproxUnrollingConfig,
1616
HessianConfig,
1717
HessianPipelineConfig,
@@ -24,14 +24,14 @@
2424
TrainingConfig,
2525
ValidationConfig,
2626
)
27+
from ..config.config_io import save_run_config
2728
from ..diagnose import DiagnoseConfig, diagnose
2829
from ..hessians.hessian_approximations import approximate_hessians
2930
from ..magic import MagicConfig, run_magic
3031
from ..process_grads import mix_autocorrelation_matrices
3132
from ..query.query_index import query
3233
from ..score.score import score_dataset
3334
from ..utils.worker_utils import validate_run_path
34-
from .config_io import save_run_config
3535

3636

3737
@dataclass
@@ -82,7 +82,10 @@ def execute(self):
8282
"if skip_index is True HessianConfig.method must be provided"
8383
)
8484

85-
if self.hessian_cfg is not None and self.hessian_cfg.method != "autocorrelation":
85+
if (
86+
self.hessian_cfg is not None
87+
and self.hessian_cfg.method != "autocorrelation"
88+
):
8689
raise ValueError(
8790
f"build only supports autocorrelation Hessians, got "
8891
f"'{self.hessian_cfg.method}'. Use the `hessian` command for "
@@ -224,7 +227,6 @@ def execute(self):
224227
score_dataset(self.index_cfg, self.score_cfg, self.preprocess_cfg)
225228

226229

227-
228230
@dataclass
229231
class Trackstar(Serializable):
230232
"""Run hessians, build, and score as a single pipeline."""
@@ -234,7 +236,7 @@ class Trackstar(Serializable):
234236
trackstar_cfg: TrackstarConfig
235237

236238
def execute(self):
237-
from ..trackstar import trackstar
239+
from .trackstar import trackstar
238240

239241
save_run_config(self, self.index_cfg.run_path)
240242
trackstar(self.index_cfg, self.trackstar_cfg)

bergson/cli/config_io.py

Lines changed: 0 additions & 163 deletions
This file was deleted.

0 commit comments

Comments
 (0)