Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ comment.cut
# Code editors
.idea/
.vscode/
.env
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ repos:
- id: check-toml
- id: check-xml
- id: check-yaml
args: [--unsafe]
exclude: |
(?x)^(
test/resources/config/config_with_duplicate_parameters_3.yaml
Expand Down
6 changes: 6 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ This can be nested arbitrarily deeply (be aware of combinatorial explosion of th

If a parameter is defined in (at least) two **different blocks** in `[grid, random, fixed]` on the same level, `seml` will throw an error to avoid ambiguity.
If a parameter is re-defined in a sub-configuration, the redefinition overrides any previous definitions of that parameter.
To remove a key inherited from a lower-priority config instead of overriding it, set it to `!remove`:
```yaml
large_datasets:
  fixed:
    regularization: !remove  # removes the key set in the root fixed block
```

### Grid parameters
In an experiment config, under `grid` you can define parameters that should be sampled from a regular grid. Currently supported
Expand Down
1 change: 1 addition & 0 deletions src/seml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
from seml.evaluation import * # noqa
from seml.experiment import Experiment # noqa
from seml.experiment.observers import * # noqa
from seml.utils import REMOVE # noqa

__version__ = importlib.metadata.version(__package__ or __name__)
67 changes: 64 additions & 3 deletions src/seml/experiment/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,6 +1246,61 @@ def check_slurm_config(experiments_per_job: int, sbatch_options: SBatchOptions):
)


# Table of sbatch options that cannot be combined. Every entry lists option groups that
# exclude one another: once the higher-priority config sets a key from one group, the keys
# of all other groups in that entry are dropped from the base config.
# Keys are matched with their leading dashes stripped, so '--cpus-per-task' and
# 'cpus-per-task' are treated alike.
_SBATCH_MUTUALLY_EXCLUSIVE: list[list[frozenset[str]]] = [
    # --cpus-per-gpu cannot be used together with --cpus-per-task (-c).
    [frozenset(('cpus-per-task', 'c')), frozenset(('cpus-per-gpu',))],
    # Only one of --mem, --mem-per-cpu and --mem-per-gpu may be given.
    [
        frozenset(('mem',)),
        frozenset(('mem-per-cpu',)),
        frozenset(('mem-per-gpu',)),
    ],
    # --exclusive excludes --oversubscribe (-s) and vice versa.
    [frozenset(('exclusive',)), frozenset(('oversubscribe', 's'))],
    # --core-spec (-S) excludes --thread-spec and vice versa.
    [frozenset(('core-spec', 'S')), frozenset(('thread-spec',))],
    # --ntasks-per-gpu is incompatible with each of --gpus-per-task, --gpus-per-socket and
    # --ntasks-per-node. One two-group entry is emitted per partner, because the partners
    # are not necessarily incompatible with each other.
    *(
        [frozenset(('ntasks-per-gpu',)), frozenset((partner,))]
        for partner in ('gpus-per-task', 'gpus-per-socket', 'ntasks-per-node')
    ),
]


def _merge_sbatch_options(
    base: dict[str, Any], override: dict[str, Any]
) -> SBatchOptions:
    """Merge sbatch options and drop inherited options that conflict with the override.

    ``override`` is merged on top of ``base`` with ``merge_dicts``. Afterwards, for every
    entry of ``_SBATCH_MUTUALLY_EXCLUSIVE``, keys inherited from ``base`` that belong to a
    group conflicting with one set by ``override`` are deleted from the result (e.g. an
    inherited ``cpus-per-task`` is dropped when ``override`` sets ``cpus-per-gpu``),
    mirroring normal override precedence.

    Override keys that are absent from the merged result do not trigger this cleanup.
    ``merge_dicts`` deletes a key whose override value is ``REMOVE``, so such an entry
    unsets an option instead of setting it and must not evict the alternatives of that
    option (``mem: !remove`` must leave an inherited ``mem-per-cpu`` untouched).

    Parameters
    ----------
    base: dict
        Lower-priority sbatch options.
    override: dict
        Higher-priority sbatch options.

    Returns
    -------
    SBatchOptions
        The merged options with inherited, conflicting keys removed.
    """
    merged: dict[str, Any] = dict(merge_dicts(base, override))

    def normalize(key: str) -> str:
        # Strip leading dashes so that '--mem' and 'mem' compare equal.
        return str.lstrip(key, '-')

    # Only options the override actually sets may evict inherited ones. A key that the
    # override removed is no longer part of ``merged`` and is therefore ignored here.
    override_normalized = {normalize(k) for k in override if k in merged}

    for exclusive_groups in _SBATCH_MUTUALLY_EXCLUSIVE:
        activated = {
            i
            for i, group in enumerate(exclusive_groups)
            if override_normalized & group
        }
        if not activated:
            continue
        # All option names from the groups the override did not touch.
        conflicting = {
            name
            for j, group in enumerate(exclusive_groups)
            if j not in activated
            for name in group
        }
        # Collect first, delete afterwards: ``merged`` must not change while iterated.
        to_remove = [
            k for k in merged if normalize(k) in conflicting and k not in override
        ]
        for k in to_remove:
            logging.info(
                f"Removed inherited sbatch option '{k}' because it conflicts with an override."
            )
            del merged[k]

    return cast(SBatchOptions, merged)


def assemble_slurm_config_dict(experiment_slurm_config: SlurmConfig):
"""
Realize inheritance for the slurm configuration, with the following relationship:
Expand Down Expand Up @@ -1274,13 +1329,19 @@ def assemble_slurm_config_dict(experiment_slurm_config: SlurmConfig):
raise ConfigError(
f"sbatch options template '{sbatch_options_template}' not found in settings.py."
)
slurm_config_base['sbatch_options'] = merge_dicts(
slurm_config_base['sbatch_options'],
SETTINGS.SBATCH_OPTIONS_TEMPLATES[sbatch_options_template],
slurm_config_base['sbatch_options'] = _merge_sbatch_options(
dict(slurm_config_base['sbatch_options']),
dict(SETTINGS.SBATCH_OPTIONS_TEMPLATES[sbatch_options_template]),
)

# Integrate experiment specific config
exp_sbatch_options = dict(slurm_config.get('sbatch_options', {}))
slurm_config = merge_dicts(slurm_config_base, slurm_config)
if exp_sbatch_options:
slurm_config['sbatch_options'] = _merge_sbatch_options(
dict(slurm_config_base['sbatch_options']),
exp_sbatch_options,
)

slurm_config['sbatch_options'] = cast(
SBatchOptions,
Expand Down
25 changes: 24 additions & 1 deletion src/seml/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,24 @@ def chunker(seq: S, size: int) -> Generator[S]:
yield from (cast(S, seq[pos : pos + size]) for pos in range(0, len(seq), size))


class _RemoveSentinel:
    """Marker value that deletes a key during dict merging.

    Put ``REMOVE`` (``!remove`` in YAML) as a config value to unset a key inherited from a
    lower-priority config. The class is a singleton: every instantiation yields the same
    object.
    """

    # The one shared instance; created on first instantiation.
    _instance: _RemoveSentinel | None = None

    def __new__(cls) -> _RemoveSentinel:
        existing = cls._instance
        if existing is not None:
            return existing
        created = super().__new__(cls)
        cls._instance = created
        return created

    def __repr__(self) -> str:
        return 'REMOVE'


REMOVE = _RemoveSentinel()


D = TypeVar('D', bound=Mapping)


Expand All @@ -283,6 +301,9 @@ def merge_dicts(dict1: Mapping, dict2: Mapping) -> Mapping:
value, this will call itself recursively to merge these dictionaries.
This does not modify the input dictionaries (creates an internal copy).

Setting a value to ``REMOVE`` (or ``!remove`` in YAML) in dict2 will remove that key
from the result even if it was present in dict1.

Parameters
----------
dict1: dict
Expand All @@ -304,7 +325,9 @@ def merge_dicts(dict1: Mapping, dict2: Mapping) -> Mapping:
return_dict = copy.deepcopy(dict1)

for k, v in dict2.items():
if k not in dict1:
if isinstance(v, _RemoveSentinel):
return_dict.pop(k, None)
elif k not in dict1:
return_dict[k] = v
else:
if isinstance(v, dict) and isinstance(dict1[k], dict):
Expand Down
6 changes: 6 additions & 0 deletions src/seml/utils/yaml.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import yaml

from seml.utils import REMOVE
from seml.utils.errors import ConfigError


Expand Down Expand Up @@ -30,6 +31,11 @@ def construct_mapping(loader, node, deep=False):
construct_mapping,
)

def _construct_remove(loader, node):
    """Return the ``REMOVE`` sentinel for a node tagged ``!remove``; the node content is ignored."""
    return REMOVE


# Make ``!remove`` in a YAML config load as the REMOVE sentinel.
YamlUniqueLoader.add_constructor('!remove', _construct_remove)


class YamlDumper(yaml.Dumper):
def represent_mapping(self, tag, mapping, flow_style=None):
Expand Down
10 changes: 10 additions & 0 deletions test/resources/config/config_slurm_cpus_per_gpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
seml:
executable: test_config.py
name: example_experiment
output_dir: logs
project_root_dir: ../..

slurm:
- sbatch_options_template: GPU
sbatch_options:
cpus-per-gpu: 4
9 changes: 9 additions & 0 deletions test/resources/config/config_slurm_mem_per_cpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
seml:
executable: test_config.py
name: example_experiment
output_dir: logs
project_root_dir: ../..

slurm:
- sbatch_options:
mem-per-cpu: 4G
9 changes: 9 additions & 0 deletions test/resources/config/config_slurm_remove_sentinel.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
seml:
executable: test_config.py
name: example_experiment
output_dir: logs
project_root_dir: ../..

slurm:
- sbatch_options:
mem: !remove
38 changes: 38 additions & 0 deletions test/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def tearDownClass(cls):
CONFIG_RESOLVE_INTERPOLATION = (
"resources/config/config_resolve_with_interpolation.yaml"
)
CONFIG_SLURM_REMOVE_SENTINEL = "resources/config/config_slurm_remove_sentinel.yaml"
CONFIG_SLURM_CPUS_PER_GPU = "resources/config/config_slurm_cpus_per_gpu.yaml"
CONFIG_SLURM_MEM_PER_CPU = "resources/config/config_slurm_mem_per_cpu.yaml"

EXPERIMENT_RESOLVE_CONFIG = "resources/scripts/experiment_resolve_config.py"
EXPERIMENT_RESOLVE_INTERPOLATION = (
Expand Down Expand Up @@ -422,3 +425,38 @@ def test_config_experiments_error(self):
slurm_conf = config.read_config(self.CONFIG_SLURM_EXPERIMENTS_AND_TASKS)[1]
with self.assertRaises(ConfigError):
assemble_slurm_config_dict(slurm_conf[0])

def test_remove_sentinel_in_sbatch_options(self):
    """``!remove`` in an experiment's sbatch_options deletes the inherited key only."""
    _, slurm_configs, _ = read_config(self.CONFIG_SLURM_REMOVE_SENTINEL)
    sbatch_options = assemble_slurm_config_dict(slurm_configs[0])["sbatch_options"]

    # The key marked with !remove must be gone ...
    self.assertNotIn("mem", sbatch_options)
    # ... while the remaining default keys are kept.
    for kept_key in ("time", "cpus-per-task"):
        self.assertIn(kept_key, sbatch_options)

def test_mutually_exclusive_cpus_per_gpu_removes_cpus_per_task(self):
    """cpus-per-gpu in the experiment config evicts the inherited cpus-per-task.

    cpus-per-task comes from the GPU template (and the base defaults).
    """
    _, slurm_configs, _ = read_config(self.CONFIG_SLURM_CPUS_PER_GPU)
    assembled = assemble_slurm_config_dict(slurm_configs[0])
    options = assembled["sbatch_options"]

    # The explicitly configured option is present with its value.
    self.assertIn("cpus-per-gpu", options)
    self.assertEqual(options["cpus-per-gpu"], 4)
    # Its mutually exclusive counterpart has been removed.
    self.assertNotIn("cpus-per-task", options)
    # Keys from the template and defaults that do not conflict are untouched.
    for kept_key in ("mem", "gres"):
        self.assertIn(kept_key, options)

def test_mutually_exclusive_mem_per_cpu_removes_mem(self):
    """mem-per-cpu in the experiment config evicts mem inherited from the base defaults."""
    _, slurm_configs, _ = read_config(self.CONFIG_SLURM_MEM_PER_CPU)
    assembled = assemble_slurm_config_dict(slurm_configs[0])
    options = assembled["sbatch_options"]

    # The explicitly configured option is present with its value.
    self.assertIn("mem-per-cpu", options)
    self.assertEqual(options["mem-per-cpu"], "4G")
    # Its mutually exclusive counterpart has been removed.
    self.assertNotIn("mem", options)
    # Defaults that do not conflict (cpus-per-task, time) are untouched.
    for kept_key in ("cpus-per-task", "time"):
        self.assertIn(kept_key, options)
Loading
Loading