forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmetrics_visualizer.py
More file actions
288 lines (233 loc) · 10.6 KB
/
metrics_visualizer.py
File metadata and controls
288 lines (233 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
"""MLflow metrics visualization utilities for SageMaker training jobs."""
import logging
from typing import Optional, List, Dict, Any
from sagemaker.core.resources import TrainingJob
# Module-level logger named after this module (via __name__).
logger = logging.getLogger(__name__)
def _is_in_studio() -> bool:
    """Report whether a SageMaker Studio domain id can be read from metadata.

    Returns:
        True when ``_read_domain_id_from_metadata`` yields any non-None value,
        False otherwise.
    """
    from sagemaker.train.common_utils.finetune_utils import (
        _read_domain_id_from_metadata as read_domain_id,
    )

    detected_domain = read_domain_id()
    return detected_domain is not None
def _get_studio_base_url(region: str) -> str:
    """Build the SageMaker Studio base URL for the detected domain.

    Args:
        region: AWS region name placed in the Studio hostname.

    Returns:
        ``https://studio-<domain_id>.studio.<region>.sagemaker.aws``. Returns an
        empty string when the metadata domain id or ``region`` is empty/missing.
    """
    from sagemaker.train.common_utils.finetune_utils import _read_domain_id_from_metadata

    studio_domain = _read_domain_id_from_metadata()
    if studio_domain and region:
        return f"https://studio-{studio_domain}.studio.{region}.sagemaker.aws"
    return ""
def _parse_job_arn(job_arn: str):
"""Parse a SageMaker job ARN into (region, resource) or None."""
import re
m = re.match(r'arn:aws(?:-[a-z]+)?:sagemaker:([a-z0-9-]+):\d+:(\S+)', job_arn)
return (m.group(1), m.group(2)) if m else None
def get_console_job_url(job_arn: str) -> str:
    """Build the SageMaker console page URL for a training, processing or transform job.

    Args:
        job_arn: Full ARN like arn:aws:sagemaker:us-east-1:123:training-job/my-job

    Returns:
        The regional console URL for the job, or an empty string if the ARN
        cannot be parsed or refers to an unsupported resource type.
    """
    parsed = _parse_job_arn(job_arn)
    if parsed is None:
        return ""
    region, resource = parsed

    # (resource prefix, console hash route) pairs; the prefixes are mutually exclusive.
    console_routes = (
        ("training-job/", "#/jobs/"),
        ("processing-job/", "#/processing-jobs/"),
        ("transform-job/", "#/transform-jobs/"),
    )
    route = next(
        (fragment for prefix, fragment in console_routes if resource.startswith(prefix)),
        None,
    )
    if route is None:
        return ""

    _, job_name = resource.split("/", 1)
    return f"https://{region}.console.aws.amazon.com/sagemaker/home?region={region}{route}{job_name}"
def get_cloudwatch_logs_url(job_arn: str) -> str:
    """Build a CloudWatch Logs console URL filtered to a SageMaker job's log streams.

    Args:
        job_arn: Full ARN of a training, processing or transform job.

    Returns:
        A CloudWatch console URL opening the job type's log group, with the
        stream-name filter set to the job name. Returns an empty string when the
        ARN is unparseable or not one of the supported job types.
    """
    parsed = _parse_job_arn(job_arn)
    if parsed is None:
        return ""
    region, resource = parsed

    job_kind, separator, job_name = resource.partition("/")
    log_groups = {
        "training-job": "/aws/sagemaker/TrainingJobs",
        "processing-job": "/aws/sagemaker/ProcessingJobs",
        "transform-job": "/aws/sagemaker/TransformJobs",
    }
    log_group = log_groups.get(job_kind) if separator else None
    if log_group is None:
        return ""

    # Inside the console fragment, "/" is written as "$252F".
    # "?" and "=" are written as "$3F" and "$3D" respectively.
    encoded_group = log_group.replace("/", "$252F")
    return (
        f"https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}"
        f"#logsV2:log-groups/log-group/{encoded_group}"
        f"$3FlogStreamNameFilter$3D{job_name}"
    )
def get_studio_url(training_job, domain_id: Optional[str] = None) -> str:
    """Get SageMaker Studio URL for training job logs.

    Args:
        training_job: SageMaker TrainingJob object, job name string, or job ARN string.
        domain_id: Studio domain ID (e.g., 'd-xxxxxxxxxxxx'). If not provided,
            it is auto-detected from Studio metadata.

    Returns:
        Studio URL pointing to the training job details. Returns an empty string
        when no URL can be resolved, in any of these cases:
        - no domain id is available;
        - the region is unknown;
        - the string is an ARN for a resource other than a training job;
        - the string is some other URI that cannot be a job name.

    Example:
        >>> from sagemaker.train import get_studio_url
        >>> url = get_studio_url('my-training-job')
        >>> url = get_studio_url('arn:aws:sagemaker:us-west-2:123456789:training-job/my-job')
    """
    if isinstance(training_job, str):
        parsed = _parse_job_arn(training_job)
        if parsed is not None:
            region, resource = parsed
            job_kind, _, job_name = resource.partition("/")
            # Only training jobs have a /jobs/train/ page in Studio.
            if job_kind != "training-job" or not job_name:
                return ""
        elif not training_job or ":" in training_job or "/" in training_job:
            # SageMaker job names are limited to letters, digits and hyphens.
            # A string containing ':' or '/' is therefore an ARN or URI, not a
            # job name, and must not be looked up via TrainingJob.get.
            return ""
        else:
            # Plain job name — use session region
            training_job = TrainingJob.get(training_job_name=training_job)
            from sagemaker.core.utils.utils import SageMakerClient
            region = SageMakerClient().region_name
            job_name = training_job.training_job_name
    else:
        from sagemaker.core.utils.utils import SageMakerClient
        region = SageMakerClient().region_name
        job_name = training_job.training_job_name

    if domain_id:
        # Honor an explicit domain id instead of reading Studio metadata.
        base = f"https://studio-{domain_id}.studio.{region}.sagemaker.aws" if region else ""
    else:
        base = _get_studio_base_url(region)
    if not base:
        return ""
    return f"{base}/jobs/train/{job_name}"
def display_job_links_html(rows: list, as_html: bool = False):
    """Render job/resource links with copy-to-clipboard buttons as a Jupyter HTML table.

    Args:
        rows: List of dicts, each with keys:
            - label (str): Row label (e.g. step name, "Training Job", "MLflow Experiment")
            - arn (str): The ARN or URI to display and copy
            - url (Optional[str]): Clickable link URL. If None, it is resolved via
              get_studio_url. The link cell is left empty when that resolution
              yields nothing or raises.
            - url_text (Optional[str]): Link display text. Defaults to "🔗 link"
            - url_hint (Optional[str]): Hint text after link. Defaults to "(please sign in to Studio first)"
        as_html: If True, return HTML object instead of displaying it.

    Returns:
        HTML object if as_html=True, otherwise None.
    """
    from IPython.display import display, HTML
    import html as html_mod
    import json

    html_rows = []
    for row in rows:
        arn = row['arn']
        escaped_arn = html_mod.escape(arn)
        escaped_label = html_mod.escape(row['label'])
        url = row.get('url')
        if url is None:
            # Link lookup is best-effort. For example, the job may not exist or
            # credentials may be missing. A failed lookup must not prevent the
            # rest of the table from rendering.
            try:
                url = get_studio_url(arn)
            except Exception:
                logger.debug("Could not resolve Studio URL for %s", arn, exc_info=True)
                url = ""
        url_text = row.get('url_text', '🔗 link')
        url_hint = row.get('url_hint', '(please sign in to Studio first)')
        link_html = ""
        if url:
            link_html = (
                f'<a href="{html_mod.escape(url)}" target="_blank" '
                f'style="color:var(--jp-brand-color1,#4a90d9);text-decoration:none;">{html_mod.escape(url_text)}</a>'
                f' <span style="color:var(--jp-ui-font-color2,#888);font-size:11px;">{html_mod.escape(url_hint)}</span>'
            )
        # First encode the value as a JS string literal, then HTML-escape it for
        # the attribute. The browser decodes entities before running the handler,
        # so HTML escaping alone would let a quote or backslash break the JS string.
        js_arn = html_mod.escape(json.dumps(arn))
        copy_btn = (
            f'<button onclick="navigator.clipboard.writeText({js_arn})'
            f'.then(()=>{{this.textContent=\'✓\';setTimeout(()=>this.textContent=\'📋\',1500)}})"'
            f' style="border:1px solid var(--jp-border-color1,#555);'
            f'background:var(--jp-layout-color2,#333);color:var(--jp-ui-font-color0,white);'
            f'border-radius:3px;cursor:pointer;font-size:11px;padding:1px 5px;"'
            f' title="Copy">📋</button>'
        )
        html_rows.append(
            f'<tr>'
            f'<td style="padding:4px 8px;text-align:left;font-weight:bold;color:var(--jp-brand-color1,#4fc3f7);">{escaped_label}</td>'
            f'<td style="padding:4px 8px;text-align:left;">{link_html}</td>'
            f'<td style="padding:4px 8px;text-align:left;">'
            f'<code style="font-size:12px;word-break:break-all;">{escaped_arn}</code>'
            f' {copy_btn}</td>'
            f'</tr>'
        )
    result = HTML(
        f'<table style="border-collapse:collapse;margin:4px 0;color:var(--jp-ui-font-color0,inherit);">'
        f'<tr style="border-bottom:1px solid var(--jp-border-color1,#555);">'
        f'<th style="padding:4px 8px;text-align:left;color:var(--jp-brand-color2,#ce93d8);">Step</th>'
        f'<th style="padding:4px 8px;text-align:left;color:var(--jp-brand-color2,#ce93d8);">Job Link</th>'
        f'<th style="padding:4px 8px;text-align:left;color:var(--jp-brand-color2,#ce93d8);">Job ARN</th>'
        f'</tr>{"".join(html_rows)}</table>'
    )
    if as_html:
        return result
    display(result)
def plot_training_metrics(
    training_job: TrainingJob,
    metrics: Optional[List[str]] = None,
    figsize: tuple = (12, 6)
) -> None:
    """Plot training metrics from MLflow for a completed training job.

    Layout:
    - Each metric with a recorded history gets its own subplot.
    - Subplots are laid out two per row; a single metric uses one full-width plot.
    - Metrics with no history are skipped.
    - If nothing is left to plot, a warning is logged and nothing is drawn.

    Args:
        training_job: SageMaker TrainingJob object or job name string
        metrics: List of metric names to plot. If None (or empty), plots all
            metrics recorded on the job's MLflow run.
        figsize: (width, height) of one subplot row. The total height is
            multiplied by the number of rows.

    Raises:
        ValueError: If the training job has no MLflow config or MLflow run id.
    """
    import matplotlib.pyplot as plt
    import mlflow
    from mlflow.tracking import MlflowClient
    from IPython.display import display

    # Suppress botocore credential messages below WARNING level.
    logging.getLogger('botocore.credentials').setLevel(logging.WARNING)

    if isinstance(training_job, str):
        training_job = TrainingJob.get(training_job_name=training_job)

    mlflow_details = getattr(training_job, 'mlflow_details', None)
    mlflow_config = getattr(training_job, 'mlflow_config', None)
    run_id = mlflow_details.mlflow_run_id if mlflow_details else None
    if not run_id or not mlflow_config:
        raise ValueError(
            f"Training job {training_job.training_job_name} has no associated MLflow run to plot."
        )

    mlflow.set_tracking_uri(mlflow_config.mlflow_resource_arn)
    client = MlflowClient()
    run = mlflow.get_run(run_id)
    metrics_to_plot = metrics or list(run.data.metrics.keys())

    # Fetch metric histories, keeping only metrics that actually have data points.
    metric_data = {}
    for metric_name in metrics_to_plot:
        history = client.get_metric_history(run_id, metric_name)
        if history:
            metric_data[metric_name] = history

    num_metrics = len(metric_data)
    if num_metrics == 0:
        # plt.subplots would reject a zero-row grid; report instead of crashing.
        logger.warning(
            "No metric history found for training job %s; nothing to plot.",
            training_job.training_job_name,
        )
        return

    ncols = 1 if num_metrics == 1 else 2
    nrows = (num_metrics + ncols - 1) // ncols
    # squeeze=False always returns a 2-D array of Axes, so flattening works
    # uniformly for every grid size (a bare ndarray is not an Axes).
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(figsize[0], figsize[1] * nrows), squeeze=False
    )
    axes = axes.flatten()

    for ax, (metric_name, history) in zip(axes, metric_data.items()):
        steps = [h.step for h in history]
        values = [h.value for h in history]
        ax.plot(steps, values, linewidth=2, marker='o', markersize=4)
        ax.set_xlabel('Step')
        ax.set_ylabel('Value')
        ax.set_title(metric_name, fontweight='bold')
        ax.grid(True, alpha=0.3)

    # Hide the leftover cell when an odd number of metrics fills a 2-column grid.
    for ax in axes[num_metrics:]:
        ax.set_visible(False)

    fig.suptitle(f'Training Metrics: {training_job.training_job_name}', fontweight='bold', fontsize=14)
    fig.tight_layout(rect=[0, 0, 1, 0.98])  # Leave small space for suptitle
    display(fig)
    plt.close(fig)
def get_available_metrics(training_job: TrainingJob) -> List[str]:
    """List the metric names logged to a training job's MLflow run.

    Args:
        training_job: SageMaker TrainingJob object or job name string

    Returns:
        Metric names recorded on the run. Returns an empty list in any of these
        cases, without raising:
        - mlflow cannot be imported (an error is logged);
        - the job has no MLflow config;
        - the job has no MLflow details or run id.
    """
    try:
        import mlflow
    except ImportError:
        logger.error("mlflow package not installed")
        return []

    job = (
        TrainingJob.get(training_job_name=training_job)
        if isinstance(training_job, str)
        else training_job
    )

    config = getattr(job, 'mlflow_config', None)
    if not config:
        return []

    details = job.mlflow_details
    if not (details and details.mlflow_run_id):
        return []

    mlflow.set_tracking_uri(config.mlflow_resource_arn)
    run = mlflow.get_run(details.mlflow_run_id)
    return [metric_name for metric_name in run.data.metrics.keys()]