Skip to content

Commit c552d6b

Browse files
author
Rishabh Devnani
committed
feat(pipeline): Make _PipelineExecution a public class
Renames _PipelineExecution to PipelineExecution and exports it from the workflow module. Keeps the old private name available for backward compatibility. Closes #4391 --- X-AI-Prompt: Make _PipelineExecution public per GitHub issue 4391 X-AI-Tool: kiro-cli
1 parent 02e864d commit c552d6b

File tree

3 files changed

+39
-21
lines changed

3 files changed

+39
-21
lines changed

sagemaker-mlops/src/sagemaker/mlops/workflow/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
__version__ = "0.1.0"
2020

2121
# Pipeline and configuration
22-
from sagemaker.mlops.workflow.pipeline import Pipeline, PipelineGraph
22+
from sagemaker.mlops.workflow.pipeline import Pipeline, PipelineGraph, PipelineExecution
2323
from sagemaker.mlops.workflow.pipeline_experiment_config import (
2424
PipelineExperimentConfig,
2525
PipelineExperimentConfigProperty,
@@ -74,6 +74,7 @@
7474
__all__ = [
7575
# Pipeline and configuration
7676
"Pipeline",
77+
"PipelineExecution",
7778
"PipelineGraph",
7879
"PipelineExperimentConfig",
7980
"PipelineExperimentConfigProperty",

sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def start(
406406
specified, uses the latest version ID.
407407
408408
Returns:
409-
A `_PipelineExecution` instance, if successful.
409+
A `PipelineExecution` instance, if successful.
410410
"""
411411
if selective_execution_config is not None:
412412
if (
@@ -438,7 +438,7 @@ def start(
438438
lambda: self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs),
439439
botocore_client_error_code="AccessDeniedException",
440440
)
441-
return _PipelineExecution(
441+
return PipelineExecution(
442442
arn=response["PipelineExecutionArn"],
443443
sagemaker_session=self.sagemaker_session,
444444
)
@@ -602,7 +602,7 @@ def _get_parameters_for_execution(self, pipeline_execution_arn: str) -> Dict[str
602602
Returns:
603603
A parameter dict from the execution.
604604
"""
605-
pipeline_execution = _PipelineExecution(
605+
pipeline_execution = PipelineExecution(
606606
arn=pipeline_execution_arn,
607607
sagemaker_session=self.sagemaker_session,
608608
)
@@ -950,8 +950,21 @@ def _generate_step_map(steps: Sequence[Step], step_map: dict):
950950

951951

952952
@attr.s
953-
class _PipelineExecution:
954-
"""Internal class for encapsulating pipeline execution instances.
953+
class PipelineExecution:
954+
"""Encapsulates a pipeline execution instance.
955+
956+
This class can be used to interact with pipeline executions that were
957+
started from any source (Python SDK, Studio UI, console, etc.).
958+
959+
Example::
960+
961+
execution = PipelineExecution(
962+
arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/my-pipeline/execution/abc123",
963+
sagemaker_session=sagemaker_session,
964+
)
965+
execution.describe()
966+
execution.wait()
967+
execution.list_steps()
955968
956969
Attributes:
957970
arn (str): The arn of the pipeline execution.
@@ -1087,6 +1100,10 @@ def result(self, step_name: str):
10871100
)
10881101

10891102

1103+
# Backward-compatible alias for the previously private class name
1104+
_PipelineExecution = PipelineExecution
1105+
1106+
10901107
def get_function_step_result(
10911108
step_name: str,
10921109
step_list: list,

sagemaker-mlops/tests/unit/workflow/test_pipeline.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def test_pipeline_get_latest_execution_arn_none(mock_session, mock_step):
175175

176176

177177
def test_pipeline_build_parameters_from_execution(mock_session, mock_step):
178-
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
178+
from sagemaker.mlops.workflow.pipeline import PipelineExecution
179179
pipeline = Pipeline(name="test-pipeline", steps=[mock_step], sagemaker_session=mock_session)
180180

181181
mock_session.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
@@ -268,43 +268,43 @@ def test_pipeline_delete_triggers_not_found(mock_session, mock_step):
268268

269269

270270
def test_pipeline_execution_stop(mock_session):
271-
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
271+
from sagemaker.mlops.workflow.pipeline import PipelineExecution
272272

273-
execution = _PipelineExecution(arn="arn", sagemaker_session=mock_session)
273+
execution = PipelineExecution(arn="arn", sagemaker_session=mock_session)
274274
execution.stop()
275275
mock_session.sagemaker_client.stop_pipeline_execution.assert_called_once()
276276

277277

278278
def test_pipeline_execution_describe(mock_session):
279-
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
279+
from sagemaker.mlops.workflow.pipeline import PipelineExecution
280280

281-
execution = _PipelineExecution(arn="arn", sagemaker_session=mock_session)
281+
execution = PipelineExecution(arn="arn", sagemaker_session=mock_session)
282282
execution.describe()
283283
mock_session.sagemaker_client.describe_pipeline_execution.assert_called_once()
284284

285285

286286
def test_pipeline_execution_list_steps(mock_session):
287-
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
287+
from sagemaker.mlops.workflow.pipeline import PipelineExecution
288288

289289
mock_session.sagemaker_client.list_pipeline_execution_steps.return_value = {"PipelineExecutionSteps": []}
290-
execution = _PipelineExecution(arn="arn", sagemaker_session=mock_session)
290+
execution = PipelineExecution(arn="arn", sagemaker_session=mock_session)
291291
result = execution.list_steps()
292292
assert result == []
293293

294294

295295
def test_pipeline_execution_list_parameters(mock_session):
296-
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
296+
from sagemaker.mlops.workflow.pipeline import PipelineExecution
297297

298-
execution = _PipelineExecution(arn="arn", sagemaker_session=mock_session)
298+
execution = PipelineExecution(arn="arn", sagemaker_session=mock_session)
299299
execution.list_parameters(max_results=10, next_token="token")
300300
mock_session.sagemaker_client.list_pipeline_parameters_for_execution.assert_called_once()
301301

302302

303303
def test_pipeline_execution_wait(mock_session):
304-
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
304+
from sagemaker.mlops.workflow.pipeline import PipelineExecution
305305
import botocore.waiter
306306

307-
execution = _PipelineExecution(arn="arn", sagemaker_session=mock_session)
307+
execution = PipelineExecution(arn="arn", sagemaker_session=mock_session)
308308
with patch("botocore.waiter.create_waiter_with_client") as mock_waiter:
309309
mock_waiter.return_value.wait = Mock()
310310
execution.wait(delay=10, max_attempts=5)
@@ -476,22 +476,22 @@ def test_pipeline_list_versions(mock_session, mock_step):
476476

477477

478478
def test_pipeline_execution_result_waiter_error(mock_session):
479-
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
479+
from sagemaker.mlops.workflow.pipeline import PipelineExecution
480480
from botocore.exceptions import WaiterError
481481

482-
execution = _PipelineExecution(arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test/execution/exec-id", sagemaker_session=mock_session)
482+
execution = PipelineExecution(arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test/execution/exec-id", sagemaker_session=mock_session)
483483

484484
with patch.object(execution, "wait", side_effect=WaiterError("name", "reason", {})):
485485
with pytest.raises(WaiterError):
486486
execution.result("step1")
487487

488488

489489
def test_pipeline_execution_result_terminal_failure(mock_session):
490-
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
490+
from sagemaker.mlops.workflow.pipeline import PipelineExecution
491491
from botocore.exceptions import WaiterError
492492
from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT
493493

494-
execution = _PipelineExecution(arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test/execution/exec-id", sagemaker_session=mock_session)
494+
execution = PipelineExecution(arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test/execution/exec-id", sagemaker_session=mock_session)
495495
mock_session.sagemaker_client.list_pipeline_execution_steps.return_value = {
496496
"PipelineExecutionSteps": [{"StepName": "step1", "Metadata": {"TrainingJob": {"Arn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/job"}}}]
497497
}

0 commit comments

Comments
 (0)