Skip to content

Commit c07247e

Browse files
feat: MLflow metrics visualization, enhanced wait UI, and eval job links (#5662)
* Intermediary checkpoint * Evaluation job update * Fix studio domain mismatch for url, update text color, add link to evaluation job * Add underscore to fine-tune and eval job links * Update link to console, conditionally display studio link, update link color to blue * Always show console link, conditional show studio link * Minor update to execution link names * Fix region issue for studio url * Revert notebook change to original * Address PR readiness * Fix sagemaker-train unit test * Update resources_codegen based on sagemaker-core change
1 parent 2e95bc0 commit c07247e

File tree

17 files changed

+739
-69
lines changed

17 files changed

+739
-69
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ env/
4040
sagemaker_train/src/**/container_drivers/sm_train.sh
4141
sagemaker_train/src/**/container_drivers/sourcecode.json
4242
sagemaker_train/src/**/container_drivers/distributed.json
43+
.kiro

sagemaker-core/src/sagemaker/core/resources.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35788,7 +35788,7 @@ def stop(self) -> None:
3578835788
ResourceNotFound: Resource being access is not found.
3578935789
"""
3579035790

35791-
client = SageMakerClient().client
35791+
client = SageMakerClient().sagemaker_client
3579235792

3579335793
operation_input_args = {
3579435794
"TrainingJobName": self.training_job_name,
@@ -35833,15 +35833,17 @@ def wait(
3583335833
progress.add_task("Waiting for TrainingJob...")
3583435834
status = Status("Current status:")
3583535835

35836-
instance_count = (
35837-
sum(
35838-
instance_group.instance_count
35839-
for instance_group in self.resource_config.instance_groups
35840-
)
35841-
if self.resource_config.instance_groups
35842-
and not isinstance(self.resource_config.instance_groups, Unassigned)
35843-
else self.resource_config.instance_count
35844-
)
35836+
instance_count = 1 # Default
35837+
if not isinstance(self.resource_config, Unassigned):
35838+
if (hasattr(self.resource_config, 'instance_groups') and
35839+
self.resource_config.instance_groups and
35840+
not isinstance(self.resource_config.instance_groups, Unassigned)):
35841+
instance_count = sum(
35842+
instance_group.instance_count
35843+
for instance_group in self.resource_config.instance_groups
35844+
)
35845+
elif hasattr(self.resource_config, 'instance_count'):
35846+
instance_count = self.resource_config.instance_count
3584535847

3584635848
if logs:
3584735849
multi_stream_logger = MultiLogStreamHandler(

sagemaker-core/src/sagemaker/core/tools/resources_codegen.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,12 +1664,19 @@ def _get_instance_count_ref(self, resource_name: str) -> str:
16641664
"""
16651665

16661666
if resource_name == "TrainingJob":
1667-
return """(
1668-
sum(instance_group.instance_count for instance_group in self.resource_config.instance_groups)
1669-
if self.resource_config.instance_groups and not isinstance(self.resource_config.instance_groups, Unassigned)
1670-
else self.resource_config.instance_count
1671-
)
1672-
"""
1667+
return """1 # Default
1668+
if not isinstance(self.resource_config, Unassigned):
1669+
if (
1670+
hasattr(self.resource_config, "instance_groups")
1671+
and self.resource_config.instance_groups
1672+
and not isinstance(self.resource_config.instance_groups, Unassigned)
1673+
):
1674+
instance_count = sum(
1675+
instance_group.instance_count
1676+
for instance_group in self.resource_config.instance_groups
1677+
)
1678+
elif hasattr(self.resource_config, "instance_count"):
1679+
instance_count = self.resource_config.instance_count"""
16731680
elif resource_name == "TransformJob":
16741681
return "self.transform_resources.instance_count"
16751682
elif resource_name == "ProcessingJob":

sagemaker-core/src/sagemaker/core/tools/templates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def delete(
483483
@Base.add_validate_call
484484
def stop(self) -> None:
485485
{docstring}
486-
client = SageMakerClient().client
486+
client = SageMakerClient().sagemaker_client
487487
488488
operation_input_args = {{
489489
{operation_input_args}

sagemaker-train/example_notebooks/evaluate/benchmark_demo.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@
442442
],
443443
"metadata": {
444444
"kernelspec": {
445-
"display_name": "Python 3",
445+
"display_name": "py3.10.14",
446446
"language": "python",
447447
"name": "python3"
448448
},
@@ -456,7 +456,7 @@
456456
"name": "python",
457457
"nbconvert_exporter": "python",
458458
"pygments_lexer": "ipython3",
459-
"version": "3.12.12"
459+
"version": "3.10.14"
460460
}
461461
},
462462
"nbformat": 4,

sagemaker-train/pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ test = [
6161
"graphene",
6262
"IPython"
6363
]
64+
notebook = [
65+
"ipywidgets>=8.0.0",
66+
"rich>=13.0.0",
67+
"matplotlib>=3.5.0",
68+
]
6469

6570
[tool.setuptools.packages.find]
6671
where = ["src/"]

sagemaker-train/src/sagemaker/train/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,16 @@ def __getattr__(name):
5656
elif name == "get_builtin_metrics":
5757
from sagemaker.train.evaluate import get_builtin_metrics
5858
return get_builtin_metrics
59+
elif name == "plot_training_metrics":
60+
from sagemaker.train.common_utils.metrics_visualizer import plot_training_metrics
61+
return plot_training_metrics
62+
elif name == "get_available_metrics":
63+
from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics
64+
return get_available_metrics
65+
elif name == "get_studio_url":
66+
from sagemaker.train.common_utils.metrics_visualizer import get_studio_url
67+
return get_studio_url
68+
elif name == "get_mlflow_url":
69+
from sagemaker.train.common_utils.trainer_wait import get_mlflow_url
70+
return get_mlflow_url
5971
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

sagemaker-train/src/sagemaker/train/common_utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class _MLflowConstants:
2020

2121
# Metric names
2222
TOTAL_LOSS_METRIC = 'total_loss'
23+
LOSS_METRIC_KEYWORDS = ('loss',)
2324
EPOCH_KEYWORD = 'epoch'
2425

2526
# MLflow run tags

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
376376

377377
except Exception as e:
378378
logger.error("Exception getting fine-tuning options: %s", e)
379+
raise
379380

380381

381382
def _create_input_channels(dataset: str, content_type: Optional[str] = None,

0 commit comments

Comments
 (0)