Skip to content

Commit 0c3093f

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d1131ba commit 0c3093f

8 files changed

Lines changed: 24 additions & 33 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Data attribution methods estimate the effect on a behavior of interest of removi
55

66
## Core features
77

8-
Per-token and per-sequence attribution is available everywhere. On-disk gradient stores and on-the-fly queries are supported. Almost every feature is available through both the CLI and a programmatic interface, which use a shared set of configuration dataclasses. Configuration dataclasses are always serialized to disk so commands can be reproduced in one line. To understand every available configuration option, [check out the documentation](https://bergson.readthedocs.io/en/latest/api.html#bergson.IndexConfig).
8+
Per-token and per-sequence attribution is available everywhere. On-disk gradient stores and on-the-fly queries are supported. Almost every feature is available through both the CLI and a programmatic interface, which use a shared set of configuration dataclasses. Configuration dataclasses are always serialized to disk so commands can be reproduced in one line. To understand every available configuration option, [check out the documentation](https://bergson.readthedocs.io/en/latest/api.html#bergson.IndexConfig).
99

1010
Bergson uses FSDP2 or SimpleFSDP, BitsAndBytes, and low-level performance optimizations to support large models, datasets, and clusters. Bergson integrates with HuggingFace Transformers and Datasets, and also supports on-disk datasets in a variety of formats.
1111

bergson/approx_unrolling/approx_unrolling_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import torch
1010
from torch import Tensor
1111

12+
from bergson.cli.commands import Score
13+
from bergson.cli.config_io import save_run_config
1214
from bergson.config import (
1315
ApproxUnrollingConfig,
1416
DistributedConfig,
@@ -17,8 +19,6 @@
1719
ScoreConfig,
1820
)
1921
from bergson.data import load_scores
20-
from bergson.cli.commands import Score
21-
from bergson.cli.config_io import save_run_config
2222
from bergson.distributed import init_dist, launch_distributed_run
2323
from bergson.hessians.apply_hessian import EkfacApplicator, EkfacConfig
2424
from bergson.score.score import score_dataset

bergson/cli/commands.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""CLI command definitions.
22
33
Each command is a thin dataclass that validates, persists its ``config.yaml``,
4-
and dispatches. They live here rather than in ``__main__`` so pipelines can
5-
lazily import lower-level commands for step serialization without importing
4+
and dispatches. They live here rather than in ``__main__`` so pipelines can
5+
lazily import lower-level commands for step serialization without importing
66
the corresponding CLI entrypoint.
77
"""
88

@@ -82,7 +82,10 @@ def execute(self):
8282
"if skip_index is True HessianConfig.method must be provided"
8383
)
8484

85-
if self.hessian_cfg is not None and self.hessian_cfg.method != "autocorrelation":
85+
if (
86+
self.hessian_cfg is not None
87+
and self.hessian_cfg.method != "autocorrelation"
88+
):
8689
raise ValueError(
8790
f"build only supports autocorrelation Hessians, got "
8891
f"'{self.hessian_cfg.method}'. Use the `hessian` command for "
@@ -224,7 +227,6 @@ def execute(self):
224227
score_dataset(self.index_cfg, self.score_cfg, self.preprocess_cfg)
225228

226229

227-
228230
@dataclass
229231
class Trackstar(Serializable):
230232
"""Run hessians, build, and score as a single pipeline."""

bergson/cli/config_io.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
"""IO for Bergson config files (run with ``bergson path_to.yaml``)."""
22

3-
from typing import TypeVar, Protocol, cast
43
import subprocess
54
from datetime import datetime, timezone
65
from importlib.metadata import PackageNotFoundError
76
from importlib.metadata import version as _pkg_version
87
from pathlib import Path
9-
from typing import Any
8+
from typing import Any, Protocol, TypeVar, cast
109

1110
import yaml
1211

13-
1412
CONFIG_FILENAME = "config.yaml"
1513

1614

@@ -58,10 +56,7 @@ def _write(
5856
run_path: str | Path | None = None,
5957
) -> Path:
6058
"""Write a ``{[run_path], steps, metadata}`` document, metadata last."""
61-
doc = {
62-
"steps": steps,
63-
"metadata": make_metadata()
64-
}
59+
doc = {"steps": steps, "metadata": make_metadata()}
6560
if run_path is not None:
6661
doc["run_path"] = str(run_path)
6762

@@ -71,9 +66,7 @@ def _write(
7166
return path
7267

7368

74-
def save_run_config(
75-
command: Any, run_dir: str | Path, *, name: str | None = None
76-
):
69+
def save_run_config(command: Any, run_dir: str | Path, *, name: str | None = None):
7770
"""Write a one-step component ``config.yaml`` for ``command`` into ``run_dir``.
7871
7972
It can be run using ``bergson <run_dir>/config.yaml``.
@@ -164,6 +157,5 @@ def load_subconfig(
164157
sub = cmd_dict.get(field)
165158
if sub is None:
166159
return default
167-
168-
return cast(T, config_cls.from_dict(sub))
169160

161+
return cast(T, config_cls.from_dict(sub))

bergson/cli/config_runner.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,11 @@ def _parse_steps(
3030
f"Unknown command '{cmd_name}'. "
3131
f"Valid commands: {sorted(command_registry)}."
3232
) from None
33-
33+
3434
# Hydrate config
3535
parsed_step = cmd_cls.from_dict(cmd_dict or {}, drop_extra_fields=False)
3636

37-
parsed.append((
38-
cmd_name,
39-
parsed_step
40-
))
37+
parsed.append((cmd_name, parsed_step))
4138
return parsed
4239

4340

@@ -49,11 +46,11 @@ def run_config(config_path: str, command_registry: dict[str, type]) -> None:
4946
Each step also writes its own component ``config.yaml`` into its run directory.
5047
"""
5148
doc = read_config(config_path)
52-
49+
5350
# Top-level run path for a multi-step pipeline
5451
run_path = doc.get("run_path")
5552
steps = _parse_steps(doc["steps"], command_registry)
56-
53+
5754
multi = len(steps) > 1
5855

5956
if multi:
@@ -65,7 +62,7 @@ def run_config(config_path: str, command_registry: dict[str, type]) -> None:
6562
"logging pipeline config to %s",
6663
run_path,
6764
)
68-
65+
6966
resolved_steps = [{name: cmd.to_dict()} for name, cmd in steps]
7067
save_pipeline_config(resolved_steps, run_path)
7168

bergson/process_grads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def normalize_flat_grad(
4949

5050
def assert_autocorrelation_hessian(path: Path) -> None:
5151
"""Verify that ``path`` contains an autocorrelation hessian."""
52-
hessian_cfg = load_subconfig(path, 'hessian', HessianConfig)
53-
52+
hessian_cfg = load_subconfig(path, "hessian", HessianConfig)
53+
5454
assert hessian_cfg is not None
5555
assert hessian_cfg.method == "autocorrelation", (
5656
f"Hessian at '{path}' was computed with method "

bergson/trackstar.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def _validate(cfg: IndexConfig):
136136
trackstar_cfg.score_cfg.higher_is_better = True
137137
_validate(score_index_cfg)
138138
save_run_config(
139-
Score(trackstar_cfg.score_cfg, score_index_cfg, trackstar_cfg.preprocess_cfg),
139+
Score(
140+
trackstar_cfg.score_cfg, score_index_cfg, trackstar_cfg.preprocess_cfg
141+
),
140142
score_index_cfg.partial_run_path,
141143
)
142144
score_dataset(

tests/ekfac_tests/test_fim_accuracy.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,7 @@ def test_kfac_fim_accuracy(seq_lengths, num_batches, max_rel_error, sample, tmp_
150150
path=str(index_cfg.partial_run_path),
151151
)
152152

153-
hessian_cfg = HessianConfig(
154-
method="autocorrelation", use_dataset_labels=not sample
155-
)
153+
hessian_cfg = HessianConfig(method="autocorrelation", use_dataset_labels=not sample)
156154

157155
computer = CollectorComputer(
158156
model=model,

0 commit comments

Comments
 (0)