From ed3be955db562b0890f8cdba422acc3de0621854 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Thu, 24 Jul 2025 21:06:11 +0000 Subject: [PATCH] Add TOML support for inputs file in pyflyte run command Co-authored-by: adriano --- flytekit/clis/sdk_in_container/run.py | 26 ++++++++--- .../unit/cli/pyflyte/my_wf_input.toml | 45 +++++++++++++++++++ tests/flytekit/unit/cli/pyflyte/test_run.py | 11 +++++ 3 files changed, 75 insertions(+), 7 deletions(-) create mode 100644 tests/flytekit/unit/cli/pyflyte/my_wf_input.toml diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 44141a5cc1..6d16b2f70b 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -1,4 +1,5 @@ import asyncio +import functools import importlib import inspect import json @@ -8,11 +9,16 @@ import tempfile import typing import typing as t +from contextlib import redirect_stdout from dataclasses import dataclass, field, fields from typing import Iterator, get_args import rich_click as click import yaml +if sys.version_info >= (3, 11): + import tomllib +else: + import tomli as tomllib from click import Context from mashumaro.codecs.json import JSONEncoder from rich.progress import Progress, TextColumn, TimeElapsedColumn @@ -957,7 +963,7 @@ def __init__( ["--inputs-file"], required=False, type=click.Path(exists=True, dir_okay=False, resolve_path=True), - help="Path to a YAML | JSON file containing inputs for the workflow.", + help="Path to a YAML | JSON | TOML file containing inputs for the workflow.", ) ) super().__init__(name=name, params=params, callback=callback, help=help) @@ -971,12 +977,18 @@ def load_inputs(f: str) -> t.Dict[str, str]: try: inputs = json.loads(f) except json.JSONDecodeError as e: - raise click.BadParameter( - message=f"Could not load the inputs file. Please make sure it is a valid JSON or YAML file." - f"\n json error: {e}," - f"\n yaml error: {yaml_e}", - param_hint="--inputs-file", - ) + json_e = e + try: + inputs = tomllib.loads(f) + except Exception as e: + toml_e = e + raise click.BadParameter( + message=f"Could not load the inputs file. Please make sure it is a valid JSON, YAML, or TOML file." + f"\n json error: {json_e}," + f"\n yaml error: {yaml_e}," + f"\n toml error: {toml_e}", + param_hint="--inputs-file", + ) return inputs diff --git a/tests/flytekit/unit/cli/pyflyte/my_wf_input.toml b/tests/flytekit/unit/cli/pyflyte/my_wf_input.toml new file mode 100644 index 0000000000..47f007dfdd --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/my_wf_input.toml @@ -0,0 +1,45 @@ +a = 1 +b = "Hello" +c = 1.1 +e = [1, 2, 3] +g = "tests/flytekit/unit/cli/pyflyte/testdata/df.parquet" +h = true +i = "2020-05-01" +j = "20H" +k = "RED" +p = "None" +q = "tests/flytekit/unit/cli/pyflyte/testdata" +remote = "tests/flytekit/unit/cli/pyflyte/testdata" +image = "tests/flytekit/unit/cli/pyflyte/testdata" + +[d] +i = 1 +a = ["h", "e"] + +[f] +x = 1.0 +y = 2.0 + +[l] +hello = "world" + +[m] +a = "b" +c = "d" + +[[n]] +x = "tests/flytekit/unit/cli/pyflyte/testdata/df.parquet" + +[o] +x = ["tests/flytekit/unit/cli/pyflyte/testdata/df.parquet"] + +[[r]] +i = 1 +a = ["h", "e"] + +[s.x] +i = 1 +a = ["h", "e"] + +[t] +i = [{i = 1, a = ["h", "e"]}] \ No newline at end of file diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 848dbbf6e1..685c418750 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -299,6 +299,17 @@ def test_all_types_with_yaml_input(): assert result.exit_code == 0, result.stdout +def test_all_types_with_toml_input(): + runner = CliRunner() + + result = runner.invoke( + pyflyte.main, + ["run", os.path.join(DIR_NAME, "workflow.py"), "my_wf", "--inputs-file", os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.toml")], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.stdout + + def test_all_types_with_pipe_input(monkeypatch): runner = CliRunner() input= str(json.load(open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"),"r")))