Skip to content

Commit 9a8069b

Browse files
committed
Merge branch 'main' into ci
# Conflicts: # pyproject.toml
2 parents 50d4d0d + 8523095 commit 9a8069b

38 files changed

Lines changed: 1278 additions & 484 deletions

.readthedocs.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ build:
1414
sphinx:
1515
configuration: docs/conf.py
1616

17-
# Optionally, but recommended,
18-
# declare the Python requirements required to build your documentation
19-
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
20-
# python:
21-
# install:
22-
# - requirements: docs/requirements.txt
17+
python:
18+
install:
19+
- method: pip
20+
path: .
21+
extra_requirements:
22+
- dev

README.md

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ We view attribution as a counterfactual question: **_If we "unlearned" this trai
66
## Core features
77

88
- Gradient store for serial queries. We provide collection-time gradient compression for efficient storage, and integrate with FAISS for fast KNN search over large stores.
9-
- On-the-fly queries. Query uncompressed gradients without disk I/O overhead via a single pass over a dataset with a set of precomputed query gradients.
9+
- On-the-fly queries. Query gradients without compression or disk I/O overhead via a single pass over a dataset with a set of precomputed query gradients.
1010
- Experiment with multiple query strategies based on [LESS](https://arxiv.org/pdf/2402.04333).
1111
- Train‑time gradient collection. Capture gradients produced during training with a ~17% performance overhead.
1212
- Scalable. We use [FSDP2](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html), BitsAndBytes, and other performance optimizations to support large models, datasets, and clusters.
@@ -39,15 +39,15 @@ pip install bergson
3939
# Quickstart
4040

4141
```
42-
python -m bergson build runs/quickstart --model EleutherAI/pythia-14m --dataset NeelNanda/pile-10k --truncation
42+
bergson build runs/quickstart --model EleutherAI/pythia-14m --dataset NeelNanda/pile-10k --truncation
4343
```
4444

4545
# Usage
4646

4747
You can build an index of gradients for each training sample from the command line, using `bergson` as a CLI tool:
4848

4949
```bash
50-
python -m bergson build <output_path> --model <model_name> --dataset <dataset_name>
50+
bergson build <output_path> --model <model_name> --dataset <dataset_name>
5151
```
5252

5353
This will create a directory at `<output_path>` containing the gradients for each training sample in the specified dataset. The `--model` and `--dataset` arguments should be compatible with the Hugging Face `transformers` library. By default it assumes that the dataset has a `text` column, but you can specify other columns using `--prompt_column` and optionally `--completion_column`. The `--help` flag will show you all available options.
@@ -61,10 +61,16 @@ At the lowest level of abstraction, the `GradientCollector` context manager allo
6161

6262
## On-the-fly Query
6363

64-
You can query a large dataset without first building an index, by specifying a previously built index to query against:
64+
You can score a large dataset against a previously built query index without saving its gradients to disk:
6565

6666
```bash
67-
python -m bergson query <output_path> --model <model_name> --dataset <dataset_name> --query_path <existing_index_path> --scores_path <output_path> --score mean --save_index False
67+
bergson score <output_path> --model <model_name> --dataset <dataset_name> --query_path <existing_index_path> --score mean --projection_dim 0
68+
```
69+
70+
We provide a utility to reduce a dataset into its mean or sum query gradient, for use as a query index:
71+
72+
```bash
73+
bergson reduce <output_path> --model <model_name> --dataset <dataset_name> --method mean --unit_normalize --projection_dim 0
6874
```
6975

7076
## Index Query
@@ -144,7 +150,7 @@ collect_gradients(
144150
Where a reward signal is available we compute gradients using a weighted advantage estimate based on Dr. GRPO:
145151

146152
```bash
147-
python -m bergson build <output_path> --model <model_name> --dataset <dataset_name> --reward_column <reward_column_name>
153+
bergson build <output_path> --model <model_name> --dataset <dataset_name> --reward_column <reward_column_name>
148154
```
149155

150156
# Development

bergson/__init__.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
__version__ = "0.2.0"
22

3-
from .attributor import Attributor
43
from .collection import collect_gradients
5-
from .data import AttentionConfig, DataConfig, IndexConfig, load_gradients
6-
from .faiss_index import FaissConfig
4+
from .config import (
5+
AttentionConfig,
6+
DataConfig,
7+
IndexConfig,
8+
QueryConfig,
9+
ReduceConfig,
10+
ScoreConfig,
11+
)
12+
from .data import load_gradients
713
from .gradcheck import FiniteDiff
814
from .gradients import GradientCollector, GradientProcessor
9-
from .score_writer import MemmapScoreWriter
15+
from .query.attributor import Attributor
16+
from .query.faiss_index import FaissConfig
17+
from .score.scorer import Scorer
1018

1119
__all__ = [
1220
"collect_gradients",
@@ -19,5 +27,8 @@
1927
"IndexConfig",
2028
"DataConfig",
2129
"AttentionConfig",
22-
"MemmapScoreWriter",
30+
"Scorer",
31+
"ScoreConfig",
32+
"ReduceConfig",
33+
"QueryConfig",
2334
]

bergson/__main__.py

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,101 @@
1-
import os
21
from dataclasses import dataclass
32
from typing import Optional, Union
43

54
from simple_parsing import ArgumentParser, ConflictResolution
65

76
from .build import build
8-
from .data import IndexConfig, QueryConfig
9-
from .query import query
7+
from .config import IndexConfig, QueryConfig, ReduceConfig, ScoreConfig
8+
from .query.query_index import query
9+
from .reduce import reduce
10+
from .score.score import score_dataset
1011

1112

1213
@dataclass
1314
class Build:
14-
"""Build the gradient dataset."""
15+
"""Build a gradient index."""
1516

16-
cfg: IndexConfig
17+
index_cfg: IndexConfig
1718

1819
def execute(self):
19-
"""Build the gradient dataset."""
20-
if not self.cfg.save_index and self.cfg.skip_preconditioners:
21-
raise ValueError(
22-
"Either save_index must be True or skip_preconditioners must be False"
20+
"""Build the gradient index."""
21+
if self.index_cfg.skip_index and self.index_cfg.skip_preconditioners:
22+
raise ValueError("Either skip_index or skip_preconditioners must be False")
23+
24+
build(self.index_cfg)
25+
26+
27+
@dataclass
28+
class Reduce:
29+
"""Reduce a gradient index."""
30+
31+
index_cfg: IndexConfig
32+
33+
reduce_cfg: ReduceConfig
34+
35+
def execute(self):
36+
"""Reduce a gradient index."""
37+
if self.index_cfg.projection_dim != 0:
38+
print(
39+
"Warning: projection_dim is not 0. "
40+
"Compressed gradients will be reduced."
2341
)
2442

25-
build(self.cfg)
43+
reduce(self.index_cfg, self.reduce_cfg)
2644

2745

2846
@dataclass
29-
class Query:
30-
"""Query the gradient dataset."""
47+
class Score:
48+
"""Score a dataset against an existing gradient index."""
3149

32-
query_cfg: QueryConfig
50+
score_cfg: ScoreConfig
3351

3452
index_cfg: IndexConfig
3553

3654
def execute(self):
37-
"""Query the gradient dataset."""
38-
assert self.query_cfg.scores_path
39-
assert self.query_cfg.query_path
40-
41-
if os.path.exists(self.index_cfg.run_path) and self.index_cfg.save_index:
42-
raise ValueError(
43-
"Index path already exists and save_index is True - "
44-
"running this query will overwrite the existing gradients. "
45-
"If you meant to query the existing gradients use "
46-
"Attributor instead."
55+
"""Score a dataset against an existing gradient index."""
56+
assert self.score_cfg.query_path
57+
58+
if self.index_cfg.projection_dim != 0:
59+
print(
60+
"Warning: projection_dim is not 0. "
61+
"Compressed gradients will be scored."
4762
)
4863

49-
query(self.index_cfg, self.query_cfg)
64+
score_dataset(self.index_cfg, self.score_cfg)
65+
66+
67+
@dataclass
68+
class Query:
69+
"""Query an existing gradient index."""
70+
71+
query_cfg: QueryConfig
72+
73+
def execute(self):
74+
"""Query an existing gradient index."""
75+
query(self.query_cfg)
5076

5177

5278
@dataclass
5379
class Main:
5480
"""Routes to the subcommands."""
5581

56-
command: Union[Build, Query]
82+
command: Union[Build, Query, Reduce, Score]
5783

5884
def execute(self):
5985
"""Run the script."""
6086
self.command.execute()
6187

6288

63-
def main(args: Optional[list[str]] = None):
89+
def get_parser():
90+
"""Get the argument parser. Used for documentation generation."""
6491
parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT)
6592
parser.add_arguments(Main, dest="prog")
93+
return parser
94+
95+
96+
def main(args: Optional[list[str]] = None):
97+
"""Parse CLI arguments and dispatch to the selected subcommand."""
98+
parser = get_parser()
6699
prog: Main = parser.parse_args(args=args).prog
67100
prog.execute()
68101

bergson/build.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from tqdm.auto import tqdm
1111

1212
from bergson.collection import collect_gradients
13-
from bergson.data import IndexConfig, allocate_batches
13+
from bergson.config import IndexConfig
14+
from bergson.data import allocate_batches
1415
from bergson.utils import assert_type
1516
from bergson.worker_utils import setup_model_and_peft
1617

@@ -24,6 +25,20 @@ def build_worker(
2425
cfg: IndexConfig,
2526
ds: Dataset | IterableDataset,
2627
):
28+
"""
29+
Build worker executed per rank to collect gradients to populate the index.
30+
31+
Parameters
32+
----------
33+
rank : int
34+
Distributed rank / GPU ID for this worker.
35+
world_size : int
36+
Total number of workers participating in the run.
37+
cfg : IndexConfig
38+
Specifies the model, tokenizer, PEFT adapters, and other settings.
39+
ds : Dataset | IterableDataset
40+
The entire dataset to be indexed. A subset is assigned to each worker.
41+
"""
2742
torch.cuda.set_device(rank)
2843

2944
# These should be set by the main process
@@ -85,13 +100,22 @@ def flush(kwargs):
85100
processor.save(cfg.partial_run_path)
86101

87102

88-
def build(cfg: IndexConfig):
89-
cfg.partial_run_path.mkdir(parents=True, exist_ok=True)
90-
with (cfg.partial_run_path / "index_config.json").open("w") as f:
91-
json.dump(asdict(cfg), f, indent=2)
103+
def build(index_cfg: IndexConfig):
104+
"""
105+
Build a gradient index by distributing work across all available GPUs.
92106
93-
ds = setup_data_pipeline(cfg)
107+
Parameters
108+
----------
109+
index_cfg : IndexConfig
110+
Specifies the run path, dataset, model, tokenizer, PEFT adapters,
111+
and many other gradient collection settings.
112+
"""
113+
index_cfg.partial_run_path.mkdir(parents=True, exist_ok=True)
114+
with (index_cfg.partial_run_path / "index_config.json").open("w") as f:
115+
json.dump(asdict(index_cfg), f, indent=2)
94116

95-
launch_distributed_run("build", build_worker, [cfg, ds])
117+
ds = setup_data_pipeline(index_cfg)
96118

97-
shutil.move(cfg.partial_run_path, cfg.run_path)
119+
launch_distributed_run("build", build_worker, [index_cfg, ds])
120+
121+
shutil.move(index_cfg.partial_run_path, index_cfg.run_path)

0 commit comments

Comments
 (0)