|
| 1 | +import argparse |
| 2 | +from itertools import product |
| 3 | +import os |
| 4 | +import subprocess |
| 5 | +import sys |
| 6 | +import shutil |
| 7 | + |
| 8 | +from colorama import Fore, Style |
| 9 | +from pathlib import Path |
| 10 | + |
| 11 | +import yaml |
| 12 | + |
| 13 | +from fenn.args.parser import Parser |
| 14 | + |
| 15 | + |
| 16 | + |
| 17 | +def execute(args: argparse.Namespace) -> None: |
| 18 | + """ |
| 19 | + Execute the fenn grid command to train a model with different seeds, epoch counts, or learning rates. |
| 20 | +
|
| 21 | + Args: |
| 22 | + args: Parsed command-line arguments containing: |
| 23 | + - path: Target directory (default: current directory) |
| 24 | + """ |
| 25 | + |
| 26 | + main_path: Path = Path(args.main).resolve() if args.main else Path.cwd() / "main.py" |
| 27 | + yaml_path: Path = main_path.parent / "fenn.yaml" |
| 28 | + yaml_copy: Path = main_path.parent / "fenn_copy.yaml" |
| 29 | + try: |
| 30 | + parsed_grid: list[dict] = _parse_grid(yaml_path=yaml_path) |
| 31 | + except TemplateError as e: |
| 32 | + print(f"{Fore.RED}Template error: missing grid section{e}{Style.RESET_ALL}") |
| 33 | + sys.exit(1) |
| 34 | + shutil.copy(yaml_path, yaml_copy) |
| 35 | + try: |
| 36 | + for hyperparameter in parsed_grid: |
| 37 | + _execute_fenn(hyperparameter=hyperparameter, |
| 38 | + main_path=main_path, |
| 39 | + yaml_path=yaml_path) |
| 40 | + finally: |
| 41 | + shutil.copy(yaml_copy, yaml_path) |
| 42 | + os.remove(yaml_copy) |
| 43 | + |
| 44 | +def _build_variants(raw_grid: dict[str, list | int]) -> list[dict[str, int]]: |
| 45 | + keys = raw_grid.keys() |
| 46 | + values = [v if isinstance(v, list) else [v] for v in raw_grid.values()] |
| 47 | + return [dict(zip(keys, combo)) for combo in product(*values)] |
| 48 | + |
| 49 | +def _parse_grid(yaml_path: Path) -> list[dict[str, int]]: |
| 50 | + parsed_yaml = Parser(config_file=yaml_path).load_configuration() |
| 51 | + if parsed_yaml.get("grid") is None: |
| 52 | + raise TemplateError |
| 53 | + return _build_variants(raw_grid=parsed_yaml.get("grid").get("train")) |
| 54 | + |
| 55 | +def _execute_fenn(hyperparameter: dict[str: int], |
| 56 | + main_path: Path, |
| 57 | + yaml_path: Path) -> None: |
| 58 | + with open(yaml_path, "r", encoding="utf-8") as f: |
| 59 | + config = yaml.safe_load(f) |
| 60 | + train_data = config["train"] |
| 61 | + for key, value in hyperparameter.items(): |
| 62 | + train_data.update({key:value}) |
| 63 | + with open(yaml_path, "w", encoding="utf-8") as f: |
| 64 | + yaml.dump(config, f, allow_unicode=True, default_flow_style=False) |
| 65 | + subprocess.run(["python3", main_path]) |
| 66 | + |
| 67 | + |
| 68 | +class TemplateError(Exception): |
| 69 | + """Raised when a template has an invalid structure.""" |
| 70 | + |
| 71 | + pass |
0 commit comments