|
17 | 17 | import os |
18 | 18 | import pprint |
19 | 19 |
|
| 20 | +from absl.testing import parameterized |
20 | 21 | import numpy as np |
21 | 22 | import tensorflow as tf, tf_keras |
| 23 | +import yaml |
22 | 24 |
|
23 | 25 | from official.core import exp_factory |
24 | 26 | from official.core import test_utils |
@@ -47,7 +49,7 @@ def foo(): |
47 | 49 | return experiment_config |
48 | 50 |
|
49 | 51 |
|
50 | | -class TrainUtilsTest(tf.test.TestCase): |
| 52 | +class TrainUtilsTest(tf.test.TestCase, parameterized.TestCase): |
51 | 53 |
|
52 | 54 | def test_get_leaf_nested_dict(self): |
53 | 55 | d = {'a': {'i': {'x': 5}}} |
@@ -138,6 +140,87 @@ def test_construct_experiment_from_flags(self): |
138 | 140 | self.assertEqual(params_from_obj.trainer.train_steps, 10) |
139 | 141 | self.assertEqual(params_from_obj.trainer.validation_steps, 11) |
140 | 142 |
|
| 143 | + @parameterized.named_parameters( |
| 144 | + dict( |
| 145 | + testcase_name='strict_with_extra', |
| 146 | + strict_override=True, |
| 147 | + has_extra=True, |
| 148 | + expect_error=True, |
| 149 | + ), |
| 150 | + dict( |
| 151 | + testcase_name='non_strict_with_extra', |
| 152 | + strict_override=False, |
| 153 | + has_extra=True, |
| 154 | + expect_error=False, |
| 155 | + ), |
| 156 | + dict( |
| 157 | + testcase_name='strict_no_extra', |
| 158 | + strict_override=True, |
| 159 | + has_extra=False, |
| 160 | + expect_error=False, |
| 161 | + ), |
| 162 | + dict( |
| 163 | + testcase_name='non_strict_no_extra', |
| 164 | + strict_override=False, |
| 165 | + has_extra=False, |
| 166 | + expect_error=False, |
| 167 | + ), |
| 168 | + ) |
| 169 | + def test_parse_configuration_strict_override_file( |
| 170 | + self, strict_override, has_extra, expect_error |
| 171 | + ): |
| 172 | + tempdir = self.create_tempdir().full_path |
| 173 | + config_path = os.path.join(tempdir, 'config.yaml') |
| 174 | + override_config = { |
| 175 | + 'task': { |
| 176 | + 'model': { |
| 177 | + 'model_id': 'override', |
| 178 | + }, |
| 179 | + }, |
| 180 | + 'trainer': { |
| 181 | + 'train_steps': 500, |
| 182 | + }, |
| 183 | + } |
| 184 | + if has_extra: |
| 185 | + override_config['task']['extra_key'] = 'extra_value' |
| 186 | + |
| 187 | + with open(config_path, 'w') as f: |
| 188 | + yaml.dump(override_config, f) |
| 189 | + |
| 190 | + options = train_utils.ParseConfigOptions( |
| 191 | + experiment='foo', |
| 192 | + config_file=[config_path], |
| 193 | + strict_override=strict_override, |
| 194 | + ) |
| 195 | + |
| 196 | + if expect_error: |
| 197 | + with self.assertRaises(KeyError): |
| 198 | + train_utils.parse_configuration(options) |
| 199 | + else: |
| 200 | + params = train_utils.parse_configuration(options) |
| 201 | + self.assertEqual(params.task.model.model_id, 'override') |
| 202 | + self.assertEqual(params.trainer.train_steps, 500) |
| 203 | + if has_extra: |
| 204 | + self.assertTrue(hasattr(params.task, 'extra_key')) |
| 205 | + self.assertEqual(params.task.extra_key, 'extra_value') |
| 206 | + else: |
| 207 | + self.assertFalse(hasattr(params.task, 'extra_key')) |
| 208 | + |
| 209 | + def test_parse_configuration_strict_override_string(self): |
| 210 | + # Test non-strict loading with params_override string |
| 211 | + options_non_strict_override = train_utils.ParseConfigOptions( |
| 212 | + experiment='foo', |
| 213 | + config_file=[], |
| 214 | + params_override='task.another_extra=test,trainer.train_steps=100', |
| 215 | + strict_override=False, |
| 216 | + ) |
| 217 | + params_override = train_utils.parse_configuration( |
| 218 | + options_non_strict_override |
| 219 | + ) |
| 220 | + self.assertEqual(params_override.trainer.train_steps, 100) |
| 221 | + self.assertTrue(hasattr(params_override.task, 'another_extra')) |
| 222 | + self.assertEqual(params_override.task.another_extra, 'test') |
| 223 | + |
141 | 224 |
|
142 | 225 | class BestCheckpointExporterTest(tf.test.TestCase): |
143 | 226 |
|
|
0 commit comments