|
2 | 2 |
|
3 | 3 | from gigl.common.logger import Logger |
4 | 4 | from gigl.orchestration.kubeflow.runner import ( |
| 5 | + _assert_required_flags, |
5 | 6 | _get_parser, |
6 | 7 | _parse_additional_job_args, |
7 | 8 | _parse_labels, |
@@ -66,6 +67,48 @@ def test_parse_args_from_cli(self): |
66 | 67 | self.assertEqual(parsed_args, expected_parsed_args) |
67 | 68 | self.assertEqual(parsed_labels, expected_parsed_labels) |
68 | 69 |
|
| 70 | + def test_assert_required_flags_missing_value(self): |
| 71 | + """Test that _assert_required_flags raises ValueError when a required flag has no value.""" |
| 72 | + parser = _get_parser() |
| 73 | + # Parse with RUN action but task_config_uri not provided (will be None) |
| 74 | + args = parser.parse_args( |
| 75 | + [ |
| 76 | + "--action=run", |
| 77 | + "--resource_config_uri=gs://bucket/resource_config.yaml", |
| 78 | + "--task_config_uri=", |
| 79 | + ] |
| 80 | + ) |
| 81 | + with self.assertRaises(ValueError): |
| 82 | + _assert_required_flags(args) |
| 83 | + |
| 84 | + def test_assert_required_flags_none_value(self): |
| 85 | + """Test that _assert_required_flags raises ValueError when a required flag has no value.""" |
| 86 | + parser = _get_parser() |
| 87 | + # Parse with RUN action but task_config_uri not provided (will be None) |
| 88 | + args = parser.parse_args( |
| 89 | + [ |
| 90 | + "--action=run", |
| 91 | + "--resource_config_uri=gs://bucket/resource_config.yaml", |
| 92 | + ] |
| 93 | + ) |
| 94 | + with self.assertRaises(ValueError): |
| 95 | + _assert_required_flags(args) |
| 96 | + |
| 97 | + def test_assert_required_flags_success(self): |
| 98 | + """Test that _assert_required_flags succeeds when all required flags are present.""" |
| 99 | + parser = _get_parser() |
| 100 | + # Parse with RUN action and all required flags |
| 101 | + args = parser.parse_args( |
| 102 | + [ |
| 103 | + "--action=run", |
| 104 | + "--task_config_uri=gs://bucket/task_config.yaml", |
| 105 | + "--resource_config_uri=gs://bucket/resource_config.yaml", |
| 106 | + ] |
| 107 | + ) |
| 108 | + |
| 109 | + # Should not raise any exception |
| 110 | + _assert_required_flags(args) |
| 111 | + |
69 | 112 |
|
70 | 113 | if __name__ == "__main__": |
71 | 114 | unittest.main() |
0 commit comments