|
1 | | -import os |
2 | 1 | from dataclasses import dataclass |
3 | 2 | from typing import Optional, Union |
4 | 3 |
|
5 | 4 | from simple_parsing import ArgumentParser, ConflictResolution |
6 | 5 |
|
7 | 6 | from .build import build |
8 | | -from .data import IndexConfig, QueryConfig |
9 | | -from .query import query |
| 7 | +from .config import IndexConfig, QueryConfig, ReduceConfig, ScoreConfig |
| 8 | +from .query.query_index import query |
| 9 | +from .reduce import reduce |
| 10 | +from .score.score import score_dataset |
10 | 11 |
|
11 | 12 |
|
12 | 13 | @dataclass |
13 | 14 | class Build: |
14 | | - """Build the gradient dataset.""" |
| 15 | + """Build a gradient index.""" |
15 | 16 |
|
16 | | - cfg: IndexConfig |
| 17 | + index_cfg: IndexConfig |
17 | 18 |
|
18 | 19 | def execute(self): |
19 | | - """Build the gradient dataset.""" |
20 | | - if not self.cfg.save_index and self.cfg.skip_preconditioners: |
21 | | - raise ValueError( |
22 | | - "Either save_index must be True or skip_preconditioners must be False" |
| 20 | + """Build the gradient index.""" |
| 21 | + if self.index_cfg.skip_index and self.index_cfg.skip_preconditioners: |
| 22 | + raise ValueError("Either skip_index or skip_preconditioners must be False") |
| 23 | + |
| 24 | + build(self.index_cfg) |
| 25 | + |
| 26 | + |
| 27 | +@dataclass |
| 28 | +class Reduce: |
| 29 | + """Reduce a gradient index.""" |
| 30 | + |
| 31 | + index_cfg: IndexConfig |
| 32 | + |
| 33 | + reduce_cfg: ReduceConfig |
| 34 | + |
| 35 | + def execute(self): |
| 36 | + """Reduce a gradient index.""" |
| 37 | + if self.index_cfg.projection_dim != 0: |
| 38 | + print( |
| 39 | + "Warning: projection_dim is not 0. " |
| 40 | + "Compressed gradients will be reduced." |
23 | 41 | ) |
24 | 42 |
|
25 | | - build(self.cfg) |
| 43 | + reduce(self.index_cfg, self.reduce_cfg) |
26 | 44 |
|
27 | 45 |
|
28 | 46 | @dataclass |
29 | | -class Query: |
30 | | - """Query the gradient dataset.""" |
| 47 | +class Score: |
| 48 | + """Score a dataset against an existing gradient index.""" |
31 | 49 |
|
32 | | - query_cfg: QueryConfig |
| 50 | + score_cfg: ScoreConfig |
33 | 51 |
|
34 | 52 | index_cfg: IndexConfig |
35 | 53 |
|
36 | 54 | def execute(self): |
37 | | - """Query the gradient dataset.""" |
38 | | - assert self.query_cfg.scores_path |
39 | | - assert self.query_cfg.query_path |
40 | | - |
41 | | - if os.path.exists(self.index_cfg.run_path) and self.index_cfg.save_index: |
42 | | - raise ValueError( |
43 | | - "Index path already exists and save_index is True - " |
44 | | - "running this query will overwrite the existing gradients. " |
45 | | - "If you meant to query the existing gradients use " |
46 | | - "Attributor instead." |
| 55 | + """Score a dataset against an existing gradient index.""" |
| 56 | + assert self.score_cfg.query_path |
| 57 | + |
| 58 | + if self.index_cfg.projection_dim != 0: |
| 59 | + print( |
| 60 | + "Warning: projection_dim is not 0. " |
| 61 | + "Compressed gradients will be scored." |
47 | 62 | ) |
48 | 63 |
|
49 | | - query(self.index_cfg, self.query_cfg) |
| 64 | + score_dataset(self.index_cfg, self.score_cfg) |
| 65 | + |
| 66 | + |
| 67 | +@dataclass |
| 68 | +class Query: |
| 69 | + """Query an existing gradient index.""" |
| 70 | + |
| 71 | + query_cfg: QueryConfig |
| 72 | + |
| 73 | + def execute(self): |
| 74 | + """Query an existing gradient index.""" |
| 75 | + query(self.query_cfg) |
50 | 76 |
|
51 | 77 |
|
52 | 78 | @dataclass |
53 | 79 | class Main: |
54 | 80 | """Routes to the subcommands.""" |
55 | 81 |
|
56 | | - command: Union[Build, Query] |
| 82 | + command: Union[Build, Query, Reduce, Score] |
57 | 83 |
|
58 | 84 | def execute(self): |
59 | 85 | """Run the script.""" |
60 | 86 | self.command.execute() |
61 | 87 |
|
62 | 88 |
|
63 | | -def main(args: Optional[list[str]] = None): |
| 89 | +def get_parser(): |
| 90 | + """Get the argument parser. Used for documentation generation.""" |
64 | 91 | parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT) |
65 | 92 | parser.add_arguments(Main, dest="prog") |
| 93 | + return parser |
| 94 | + |
| 95 | + |
| 96 | +def main(args: Optional[list[str]] = None): |
| 97 | + """Parse CLI arguments and dispatch to the selected subcommand.""" |
| 98 | + parser = get_parser() |
66 | 99 | prog: Main = parser.parse_args(args=args).prog |
67 | 100 | prog.execute() |
68 | 101 |
|
|
0 commit comments