Skip to content

Commit f9508e8

Browse files
Be able to check for None flag and empty string flag (#395)
Co-authored-by: kmontemayor <kyle.e.montemayor@gmail.com>
1 parent 7a66a06 commit f9508e8

2 files changed

Lines changed: 47 additions & 1 deletion

File tree

python/gigl/orchestration/kubeflow/runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ def _assert_required_flags(args: argparse.Namespace) -> None:
146146
for flag in required_flags:
147147
if not hasattr(args, flag):
148148
missing_flags.append(flag)
149-
elif len(getattr(args, flag)) == 0:
149+
flag = getattr(args, flag)
150+
if flag is None:
151+
missing_flags.append(flag)
152+
elif not flag:
150153
missing_values.append(flag)
151154

152155
if missing_flags:

python/tests/unit/orchestration/kubeflow/kfp_runner_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from gigl.common.logger import Logger
44
from gigl.orchestration.kubeflow.runner import (
5+
_assert_required_flags,
56
_get_parser,
67
_parse_additional_job_args,
78
_parse_labels,
@@ -66,6 +67,48 @@ def test_parse_args_from_cli(self):
6667
self.assertEqual(parsed_args, expected_parsed_args)
6768
self.assertEqual(parsed_labels, expected_parsed_labels)
6869

70+
def test_assert_required_flags_missing_value(self):
71+
"""Test that _assert_required_flags raises ValueError when a required flag has no value."""
72+
parser = _get_parser()
73+
# Parse with RUN action but task_config_uri not provided (will be None)
74+
args = parser.parse_args(
75+
[
76+
"--action=run",
77+
"--resource_config_uri=gs://bucket/resource_config.yaml",
78+
"--task_config_uri=",
79+
]
80+
)
81+
with self.assertRaises(ValueError):
82+
_assert_required_flags(args)
83+
84+
def test_assert_required_flags_none_value(self):
85+
"""Test that _assert_required_flags raises ValueError when a required flag has no value."""
86+
parser = _get_parser()
87+
# Parse with RUN action but task_config_uri not provided (will be None)
88+
args = parser.parse_args(
89+
[
90+
"--action=run",
91+
"--resource_config_uri=gs://bucket/resource_config.yaml",
92+
]
93+
)
94+
with self.assertRaises(ValueError):
95+
_assert_required_flags(args)
96+
97+
def test_assert_required_flags_success(self):
98+
"""Test that _assert_required_flags succeeds when all required flags are present."""
99+
parser = _get_parser()
100+
# Parse with RUN action and all required flags
101+
args = parser.parse_args(
102+
[
103+
"--action=run",
104+
"--task_config_uri=gs://bucket/task_config.yaml",
105+
"--resource_config_uri=gs://bucket/resource_config.yaml",
106+
]
107+
)
108+
109+
# Should not raise any exception
110+
_assert_required_flags(args)
111+
69112

70113
if __name__ == "__main__":
71114
unittest.main()

0 commit comments

Comments
 (0)