Skip to content

Commit dd7e694

Browse files
No public description
PiperOrigin-RevId: 881732289
1 parent ef3db3e commit dd7e694

3 files changed

Lines changed: 109 additions & 3 deletions

File tree

official/core/train_utils.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ class ParseConfigOptions:
292292
tpu: str = ''
293293
tf_data_service: str = ''
294294
params_override: str = ''
295+
strict_override: bool = True
295296

296297
def __contains__(self, name):
297298
return name in dataclasses.asdict(self)
@@ -330,9 +331,13 @@ def base_experiment(self):
330331

331332
def parse_config_file(self, params):
332333
"""Override the configs of params from the config_file."""
334+
is_strict = True
335+
if isinstance(self._flags_obj, ParseConfigOptions):
336+
is_strict = self._flags_obj.strict_override
333337
for config_file in self._flags_obj.config_file or []:
334338
params = hyperparams.override_params_dict(
335-
params, config_file, is_strict=True)
339+
params, config_file, is_strict=is_strict
340+
)
336341
return params
337342

338343
def parse_runtime(self, params):
@@ -363,14 +368,28 @@ def parse_data_service(self, params):
363368
return params
364369

365370
def parse_params_override(self, params):
371+
"""Overrides params from the --params_override flag.
372+
373+
Args:
374+
params: A ParamsDict object to be overridden.
375+
376+
Returns:
377+
The overridden ParamsDict object.
378+
"""
366379
# Get the second level of override from `--params_override`.
367380
# `--params_override` is typically used as a further override over the
368381
# template. For example, one may define a particular template for training
369382
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
370383
# then define different learning rates and pass it via `--params_override`.
371384
if self._flags_obj.params_override:
385+
is_strict = True
386+
if isinstance(self._flags_obj, ParseConfigOptions):
387+
is_strict = self._flags_obj.strict_override
372388
params = hyperparams.override_params_dict(
373-
params, self._flags_obj.params_override, is_strict=True)
389+
params,
390+
self._flags_obj.params_override,
391+
is_strict=is_strict,
392+
)
374393
return params
375394

376395

official/core/train_utils_test.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import os
1818
import pprint
1919

20+
from absl.testing import parameterized
2021
import numpy as np
2122
import tensorflow as tf, tf_keras
23+
import yaml
2224

2325
from official.core import exp_factory
2426
from official.core import test_utils
@@ -47,7 +49,7 @@ def foo():
4749
return experiment_config
4850

4951

50-
class TrainUtilsTest(tf.test.TestCase):
52+
class TrainUtilsTest(tf.test.TestCase, parameterized.TestCase):
5153

5254
def test_get_leaf_nested_dict(self):
5355
d = {'a': {'i': {'x': 5}}}
@@ -138,6 +140,87 @@ def test_construct_experiment_from_flags(self):
138140
self.assertEqual(params_from_obj.trainer.train_steps, 10)
139141
self.assertEqual(params_from_obj.trainer.validation_steps, 11)
140142

143+
@parameterized.named_parameters(
144+
dict(
145+
testcase_name='strict_with_extra',
146+
strict_override=True,
147+
has_extra=True,
148+
expect_error=True,
149+
),
150+
dict(
151+
testcase_name='non_strict_with_extra',
152+
strict_override=False,
153+
has_extra=True,
154+
expect_error=False,
155+
),
156+
dict(
157+
testcase_name='strict_no_extra',
158+
strict_override=True,
159+
has_extra=False,
160+
expect_error=False,
161+
),
162+
dict(
163+
testcase_name='non_strict_no_extra',
164+
strict_override=False,
165+
has_extra=False,
166+
expect_error=False,
167+
),
168+
)
169+
def test_parse_configuration_strict_override_file(
170+
self, strict_override, has_extra, expect_error
171+
):
172+
tempdir = self.create_tempdir().full_path
173+
config_path = os.path.join(tempdir, 'config.yaml')
174+
override_config = {
175+
'task': {
176+
'model': {
177+
'model_id': 'override',
178+
},
179+
},
180+
'trainer': {
181+
'train_steps': 500,
182+
},
183+
}
184+
if has_extra:
185+
override_config['task']['extra_key'] = 'extra_value'
186+
187+
with open(config_path, 'w') as f:
188+
yaml.dump(override_config, f)
189+
190+
options = train_utils.ParseConfigOptions(
191+
experiment='foo',
192+
config_file=[config_path],
193+
strict_override=strict_override,
194+
)
195+
196+
if expect_error:
197+
with self.assertRaises(KeyError):
198+
train_utils.parse_configuration(options)
199+
else:
200+
params = train_utils.parse_configuration(options)
201+
self.assertEqual(params.task.model.model_id, 'override')
202+
self.assertEqual(params.trainer.train_steps, 500)
203+
if has_extra:
204+
self.assertTrue(hasattr(params.task, 'extra_key'))
205+
self.assertEqual(params.task.extra_key, 'extra_value')
206+
else:
207+
self.assertFalse(hasattr(params.task, 'extra_key'))
208+
209+
def test_parse_configuration_strict_override_string(self):
210+
# Test non-strict loading with params_override string
211+
options_non_strict_override = train_utils.ParseConfigOptions(
212+
experiment='foo',
213+
config_file=[],
214+
params_override='task.another_extra=test,trainer.train_steps=100',
215+
strict_override=False,
216+
)
217+
params_override = train_utils.parse_configuration(
218+
options_non_strict_override
219+
)
220+
self.assertEqual(params_override.trainer.train_steps, 100)
221+
self.assertTrue(hasattr(params_override.task, 'another_extra'))
222+
self.assertEqual(params_override.task.another_extra, 'test')
223+
141224

142225
class BestCheckpointExporterTest(tf.test.TestCase):
143226

official/modeling/hyperparams/params_dict.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import collections
1818
import copy
19+
import logging
1920
import re
2021

2122
import six
@@ -192,6 +193,9 @@ def _override(self, override_dict, is_strict=True):
192193
'To extend the existing keys, use '
193194
'`override` with `is_strict` = False.'.format(k))
194195
else:
196+
logging.warning(
197+
'Adding new key `%s` to ParamsDict because is_strict=False.', k
198+
)
195199
self._set(k, v)
196200
else:
197201
if isinstance(v, dict):

0 commit comments

Comments
 (0)