Skip to content

Commit 987a5a9

Browse files
committed
Move files around
1 parent 77bceec commit 987a5a9

28 files changed

Lines changed: 167 additions & 142 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Data attribution methods estimate the effect on a behavior of interest of removi
55

66
## Core features
77

8-
Per-token and per-sequence attribution is available everywhere. On-disk gradient stores and on-the-fly queries are supported. Almost every feature is available through both the CLI and a programmatic interface, which use a shared set of configuration dataclasses. Configuration dataclasses are always serialized to disk so commands can be reproduced in one line. To understand every available configuration option, [check out the documentation](https://bergson.readthedocs.io/en/latest/api.html#bergson.IndexConfig).
8+
Per-token and per-sequence attribution is available everywhere. On-disk gradient stores and on-the-fly queries are supported. Almost every feature is available through both the CLI and a programmatic interface, which use a shared set of configuration dataclasses. Configuration dataclasses are always serialized to disk so commands can be reproduced in one line. To understand every available configuration option, [check out the documentation](https://bergson.readthedocs.io/en/latest/api.html#bergson.IndexConfig).
99

1010
Bergson uses FSDP2 or SimpleFSDP, BitsAndBytes, and low-level performance optimizations to support large models, datasets, and clusters. Bergson integrates with HuggingFace Transformers and Datasets, and also supports on-disk datasets in a variety of formats.
1111

bergson/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .collector.collector import CollectorComputer
88
from .collector.gradient_collectors import GradientCollector
99
from .collector.in_memory_collector import InMemoryCollector
10-
from .config import (
10+
from .config.config import (
1111
AttentionConfig,
1212
DataConfig,
1313
IndexConfig,

bergson/__main__.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
from dataclasses import dataclass
44
from typing import Union, get_args
55

6+
import petname
67
from simple_parsing import ArgumentParser, ConflictResolution
78

9+
from bergson.config.config_io import parse_steps, read_config, save_pipeline_config
10+
from bergson.utils.logger import get_logger
11+
812
from .cli.commands import (
913
ApproxUnrolling,
1014
Build,
@@ -20,7 +24,6 @@
2024
Train,
2125
Validate,
2226
)
23-
from .cli.config_runner import run_config
2427

2528

2629
@dataclass
@@ -48,6 +51,41 @@ def execute(self):
4851
self.command.execute()
4952

5053

54+
def run_config(config_path: str, command_registry: dict[str, type]) -> None:
55+
"""Execute each step of a bergson config YAML in order.
56+
57+
A fully resolved version of any multi-step config (including default values)
58+
is written to ``run_path`` (auto-named under ``runs/`` if not given).
59+
Each step also writes its own component ``config.yaml`` into its run directory.
60+
"""
61+
doc = read_config(config_path)
62+
63+
# Optional top-level run path for a multi-step pipeline
64+
run_path = doc.get("run_path")
65+
66+
steps = parse_steps(doc["steps"], command_registry)
67+
68+
multi = len(steps) > 1
69+
70+
if multi:
71+
if not run_path:
72+
run_name = petname.generate(2, separator="_")
73+
run_path = f"runs/{run_name}"
74+
get_logger(__name__).warning(
75+
"No top level run_path set for this multi-step YAML; "
76+
"logging pipeline config to %s",
77+
run_path,
78+
)
79+
80+
resolved_steps = [{name: cmd.to_dict()} for name, cmd in steps]
81+
save_pipeline_config(resolved_steps, run_path)
82+
83+
for i, (cmd_name, cmd) in enumerate(steps, start=1):
84+
if multi:
85+
print(f"\n[pipeline] step {i}/{len(steps)}: {cmd_name}")
86+
cmd.execute()
87+
88+
5189
def main():
5290
"""Parse CLI arguments and dispatch to the selected subcommand.
5391
@@ -66,9 +104,8 @@ def main():
66104
command_classes = get_args(Main.__dataclass_fields__["command"].type)
67105
command_registry = {cls.__name__.lower(): cls for cls in command_classes}
68106

69-
# Config-file mode: the YAML file is self-describing, so no command verb is
70-
# needed. Accept it as the sole argument; leading command words (e.g.
71-
# `bergson build run/config.yaml`) are ignored.
107+
# Config-file mode: accept a YAML file as the sole argument.
108+
# Leading command words (e.g. `bergson build run/config.yaml`) are ignored.
72109
config_path: str | None = None
73110
if len(args) == 1 and os.path.isfile(args[0]):
74111
config_path = args[0]
@@ -79,7 +116,7 @@ def main():
79116
run_config(config_path, command_registry)
80117
return
81118

82-
# CLI-flag mode: standard argparse-style flag parsing.
119+
# CLI-flag mode: argparse-style flag parsing.
83120
parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT)
84121
parser.add_arguments(Main, dest="prog")
85122
prog: Main = parser.parse_args().prog

bergson/approx_unrolling/approx_unrolling_math.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
import torch
1010
from torch import Tensor
1111

12-
from bergson.config import (
12+
from bergson.cli.commands import Score
13+
from bergson.config.config import (
1314
ApproxUnrollingConfig,
1415
DistributedConfig,
1516
IndexConfig,
1617
PreprocessConfig,
1718
ScoreConfig,
1819
)
20+
from bergson.config.config_io import save_run_config
1921
from bergson.data import load_scores
20-
from bergson.cli.commands import Score
21-
from bergson.cli.config_io import save_run_config
2222
from bergson.distributed import init_dist, launch_distributed_run
2323
from bergson.hessians.apply_hessian import EkfacApplicator, EkfacConfig
2424
from bergson.score.score import score_dataset

bergson/approx_unrolling/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@
3030

3131
from ..build import build
3232
from ..cli.commands import Build
33-
from ..cli.config_io import save_run_config
3433
from ..config import (
3534
ApproxUnrollingConfig,
3635
HessianConfig,
3736
IndexConfig,
3837
PreprocessConfig,
3938
)
39+
from ..config.config_io import save_run_config
4040
from ..utils.logger import get_logger
4141
from .approx_unrolling_math import (
4242
compute_lr_times_steps_per_segment,

bergson/approx_unrolling/precompute_checkpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
CollectorComputer,
99
fwd_bwd_hessian_factory,
1010
)
11-
from bergson.config import (
11+
from bergson.config.config import (
1212
ApproxUnrollingConfig,
1313
HessianConfig,
1414
IndexConfig,

bergson/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tqdm.auto import tqdm
99

1010
from bergson.collection import collect_gradients
11-
from bergson.config import HessianConfig, IndexConfig, PreprocessConfig
11+
from bergson.config.config import HessianConfig, IndexConfig, PreprocessConfig
1212
from bergson.data import allocate_batches
1313
from bergson.distributed import (
1414
cap_world_size_to_dataset,

bergson/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.distributed as dist
66
from datasets import Dataset
77

8-
from .config import PreprocessConfig
8+
from .config.config import PreprocessConfig
99
from .data import compute_num_token_grads, create_index, create_token_index
1010
from .process_grads import (
1111
get_trackstar_hessian,

bergson/cli/commands.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""CLI command definitions.
22
33
Each command is a thin dataclass that validates, persists its ``config.yaml``,
4-
and dispatches. They live here rather than in ``__main__`` so pipelines can
5-
lazily import lower-level commands for step serialization without importing
4+
and dispatches. They live here rather than in ``__main__`` so pipelines can
5+
lazily import lower-level commands for step serialization without importing
66
the corresponding CLI entrypoint.
77
"""
88

@@ -11,7 +11,7 @@
1111
from simple_parsing import Serializable
1212

1313
from ..build import build
14-
from ..config import (
14+
from ..config.config import (
1515
ApproxUnrollingConfig,
1616
HessianConfig,
1717
HessianPipelineConfig,
@@ -24,14 +24,14 @@
2424
TrainingConfig,
2525
ValidationConfig,
2626
)
27+
from ..config.config_io import save_run_config
2728
from ..diagnose import DiagnoseConfig, diagnose
2829
from ..hessians.hessian_approximations import approximate_hessians
2930
from ..magic import MagicConfig, run_magic
3031
from ..process_grads import mix_autocorrelation_matrices
3132
from ..query.query_index import query
3233
from ..score.score import score_dataset
3334
from ..utils.worker_utils import validate_run_path
34-
from .config_io import save_run_config
3535

3636

3737
@dataclass
@@ -82,7 +82,10 @@ def execute(self):
8282
"if skip_index is True HessianConfig.method must be provided"
8383
)
8484

85-
if self.hessian_cfg is not None and self.hessian_cfg.method != "autocorrelation":
85+
if (
86+
self.hessian_cfg is not None
87+
and self.hessian_cfg.method != "autocorrelation"
88+
):
8689
raise ValueError(
8790
f"build only supports autocorrelation Hessians, got "
8891
f"'{self.hessian_cfg.method}'. Use the `hessian` command for "
@@ -224,7 +227,6 @@ def execute(self):
224227
score_dataset(self.index_cfg, self.score_cfg, self.preprocess_cfg)
225228

226229

227-
228230
@dataclass
229231
class Trackstar(Serializable):
230232
"""Run hessians, build, and score as a single pipeline."""
@@ -234,7 +236,7 @@ class Trackstar(Serializable):
234236
trackstar_cfg: TrackstarConfig
235237

236238
def execute(self):
237-
from ..trackstar import trackstar
239+
from .trackstar import trackstar
238240

239241
save_run_config(self, self.index_cfg.run_path)
240242
trackstar(self.index_cfg, self.trackstar_cfg)

bergson/cli/config_runner.py

Lines changed: 0 additions & 75 deletions
This file was deleted.

0 commit comments

Comments
 (0)