Skip to content

Commit 372c59d

Browse files
committed
feat: [100] Added basic functionality for grid search
1 parent cf4c3e3 commit 372c59d

4 files changed

Lines changed: 93 additions & 4 deletions

File tree

src/fenn/args/parser.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from pathlib import Path
23
from typing import Any, Dict
34

45
import yaml
@@ -15,11 +16,11 @@ def __new__(cls, *args, **kwargs):
1516
cls._instance = super().__new__(cls)
1617
return cls._instance
1718

18-
def __init__(self) -> None:
19+
def __init__(self, config_file: str | Path = "fenn.yaml") -> None:
1920
if hasattr(self, "_initialized"):
2021
return
2122

22-
self._config_file: str = "fenn.yaml"
23+
self._config_file: Path = Path(config_file)
2324
self._args: Dict[str, Any] = {}
2425

2526
self._keystore: KeyStore = KeyStore()
@@ -48,7 +49,7 @@ def load_configuration(self) -> Any:
4849
# File exists → load YAML
4950
with open(self._config_file) as f:
5051
self._args = yaml.safe_load(f)
51-
self._args["project"] = self._config_file.split("/")[-1].split(".")[0]
52+
self._args["project"] = self._config_file.stem
5253

5354
return self._args
5455

@@ -64,7 +65,7 @@ def config_file(self) -> str:
6465

6566
@config_file.setter
6667
def config_file(self, config_file: str) -> None:
67-
self._config_file = config_file
68+
self._config_file: Path = Path(config_file)
6869

6970
@property
7071
def args(self) -> Dict[str, Any]:

src/fenn/cli/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import fenn.cli.list as list
66
import fenn.cli.pull as pull
77
import fenn.cli.run as run
8+
import fenn.cli.grid as grid
89

910

1011
def build_parser() -> argparse.ArgumentParser:
@@ -145,6 +146,21 @@ def build_parser() -> argparse.ArgumentParser:
145146
"--profile", default=None, help="Profile name (default: 'default')"
146147
)
147148
p_logout.set_defaults(func=auth.execute)
149+
150+
# ========= GRID =========
151+
152+
p_grid = subparsers.add_parser(
153+
"grid",
154+
help="Run a Fenn project several times, with all possible grid hyperparams",
155+
)
156+
157+
p_grid.add_argument(
158+
"main",
159+
nargs="?",
160+
help="Name of main.py file (default 'main.py')",
161+
)
162+
163+
p_grid.set_defaults(func=grid.execute)
148164

149165
return parser
150166

src/fenn/cli/grid.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import argparse
2+
from itertools import product
3+
import os
4+
import subprocess
5+
import sys
6+
import shutil
7+
8+
from colorama import Fore, Style
9+
from pathlib import Path
10+
11+
import yaml
12+
13+
from fenn.args.parser import Parser
14+
15+
16+
17+
def execute(args: argparse.Namespace) -> None:
18+
"""
19+
Execute the fenn grid command to train a model with different seeds, epoch counts, or learning rates.
20+
21+
Args:
22+
args: Parsed command-line arguments containing:
23+
- path: Target directory (default: current directory)
24+
"""
25+
26+
main_path: Path = Path(args.main).resolve() if args.main else Path.cwd() / "main.py"
27+
yaml_path: Path = main_path.parent / "fenn.yaml"
28+
yaml_copy: Path = main_path.parent / "fenn_copy.yaml"
29+
try:
30+
parsed_grid: list[dict] = _parse_grid(yaml_path=yaml_path)
31+
except TemplateError as e:
32+
print(f"{Fore.RED}Template error: missing grid section{e}{Style.RESET_ALL}")
33+
sys.exit(1)
34+
shutil.copy(yaml_path, yaml_copy)
35+
try:
36+
for hyperparameter in parsed_grid:
37+
_execute_fenn(hyperparameter=hyperparameter,
38+
main_path=main_path,
39+
yaml_path=yaml_path)
40+
finally:
41+
shutil.copy(yaml_copy, yaml_path)
42+
os.remove(yaml_copy)
43+
44+
def _build_variants(raw_grid: dict[str, list | int]) -> list[dict[str, int]]:
45+
keys = raw_grid.keys()
46+
values = [v if isinstance(v, list) else [v] for v in raw_grid.values()]
47+
return [dict(zip(keys, combo)) for combo in product(*values)]
48+
49+
def _parse_grid(yaml_path: Path) -> list[dict[str, int]]:
50+
parsed_yaml = Parser(config_file=yaml_path).load_configuration()
51+
if parsed_yaml.get("grid") is None:
52+
raise TemplateError
53+
return _build_variants(raw_grid=parsed_yaml.get("grid").get("train"))
54+
55+
def _execute_fenn(hyperparameter: dict[str: int],
56+
main_path: Path,
57+
yaml_path: Path) -> None:
58+
with open(yaml_path, "r", encoding="utf-8") as f:
59+
config = yaml.safe_load(f)
60+
train_data = config["train"]
61+
for key, value in hyperparameter.items():
62+
train_data.update({key:value})
63+
with open(yaml_path, "w", encoding="utf-8") as f:
64+
yaml.dump(config, f, allow_unicode=True, default_flow_style=False)
65+
subprocess.run(["python3", main_path])
66+
67+
68+
class TemplateError(Exception):
69+
"""Raised when a template has an invalid structure."""
70+
71+
pass

src/tests/cli/test_grid_command.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"Test for cli grid command"

0 commit comments

Comments
 (0)