Skip to content

Commit 620d093

Browse files
fix(kfpytorch): add environment property to pass elastic config via env vars
This fixes bug #1, where overriding nproc_per_node via with_overrides would not change the number of processes for single-node elastic tasks.

For single-node elastic tasks (task_type='python-task'), the _execute method reads PET_NPROC_PER_NODE, PET_NNODES, PET_MAX_RESTARTS, and PET_MONITOR_INTERVAL from environment variables. However, these environment variables were never set in the task template during serialization.

The fix adds an environment property override to PytorchElasticFunctionTask that includes the elastic config as environment variables. This ensures that when task_config is modified via with_overrides, the elastic configuration is correctly passed to the pod via environment variables.

Combined with the previous fix (dynamic task_type property), this now fully supports:
- Bug #1: single-node (1 proc) -> single-node (multiple procs) override
- Bug #2: single-node (1 proc) -> multi-node (multiple procs) override

Co-Authored-By: carlos@exa.ai <carlosmarques.personal@gmail.com>
1 parent 506d94e commit 620d093

2 files changed

Lines changed: 76 additions & 0 deletions

File tree

plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,33 @@ def task_type(self) -> str:
364364
"""
365365
return self._ELASTIC_TASK_TYPE_STANDALONE if self._task_config.nnodes == 1 else self._ELASTIC_TASK_TYPE
366366

367+
@property
def environment(self) -> Dict[str, str]:
    """
    Return the task's environment variables plus the current elastic settings.

    The result is recomputed from ``self._task_config`` on every access, so a
    ``_task_config`` replaced after construction (e.g. via ``with_overrides``
    in a dynamic task) is reflected in the variables passed to the pod.

    For single-node elastic tasks (task_type="python-task"), ``_execute`` reads
    these variables to configure torch elastic's LaunchConfig. Without them it
    would fall back to the ``_task_config`` values given to the decorator and
    ignore any overrides, which is the bug this property fixes for
    ``nproc_per_node``.

    Returns:
        A new dict containing every entry of ``self._environment`` (if any)
        plus ``PET_NNODES``, ``PET_NPROC_PER_NODE``, ``PET_MAX_RESTARTS`` and
        ``PET_MONITOR_INTERVAL``, each rendered with ``str()``. The elastic
        keys take precedence over same-named entries in ``self._environment``.
    """
    # Work on a copy so the stored mapping is never mutated by the additions
    # below; a missing or empty base environment yields a fresh dict.
    env = self._environment.copy() if self._environment else {}

    # Always write the elastic config so that _execute sees the current
    # (potentially overridden) values rather than the decorator-time ones.
    config = self._task_config
    env["PET_NNODES"] = str(config.nnodes)
    env["PET_NPROC_PER_NODE"] = str(config.nproc_per_node)
    env["PET_MAX_RESTARTS"] = str(config.max_restarts)
    env["PET_MONITOR_INTERVAL"] = str(config.monitor_interval)
    return env
393+
367394
def _execute(self, **kwargs) -> Any:
368395
"""
369396
Execute the task function using torch distributed's `elastic_launch`.

plugins/flytekit-kf-pytorch/tests/test_elastic_task.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,3 +361,52 @@ def elastic_range_task():
361361
# Update to single-node
362362
elastic_range_task._task_config = Elastic(nnodes=1, nproc_per_node=1)
363363
assert elastic_range_task.task_type == "python-task"
364+
365+
366+
def test_environment_includes_elastic_config() -> None:
    """Check that ``environment`` exposes the elastic config as PET_* variables.

    Covers the fix for the bug where overriding nproc_per_node via
    with_overrides did not change the number of processes for single-node
    elastic tasks: the values must track the task's current ``_task_config``.
    """

    # Create a task with single-node config
    @task(task_config=Elastic(nnodes=1, nproc_per_node=1, max_restarts=3, monitor_interval=10))
    def single_node_task():
        pass

    # The environment mirrors the config supplied to the decorator.
    env = single_node_task.environment
    assert env["PET_NNODES"] == "1"
    assert env["PET_NPROC_PER_NODE"] == "1"
    assert env["PET_MAX_RESTARTS"] == "3"
    assert env["PET_MONITOR_INTERVAL"] == "10"

    # Simulate what with_overrides does: update _task_config
    single_node_task._task_config = Elastic(nnodes=1, nproc_per_node=4, max_restarts=5, monitor_interval=20)

    # After updating _task_config, environment should reflect the new values
    env = single_node_task.environment
    assert env["PET_NNODES"] == "1"
    assert env["PET_NPROC_PER_NODE"] == "4"
    assert env["PET_MAX_RESTARTS"] == "5"
    assert env["PET_MONITOR_INTERVAL"] == "20"
395+
396+
397+
def test_environment_preserves_existing_env_vars() -> None:
    """Check that user-supplied environment variables survive alongside PET_* ones."""

    # Create a task with custom environment variables
    @task(
        task_config=Elastic(nnodes=1, nproc_per_node=2),
        environment={"CUSTOM_VAR": "custom_value", "ANOTHER_VAR": "another_value"},
    )
    def task_with_env():
        pass

    # Both the custom entries and the elastic entries must be present.
    env = task_with_env.environment
    assert env["CUSTOM_VAR"] == "custom_value"
    assert env["ANOTHER_VAR"] == "another_value"
    assert env["PET_NNODES"] == "1"
    assert env["PET_NPROC_PER_NODE"] == "2"

0 commit comments

Comments
 (0)