Skip to content

Commit d0cbf41

Browse files
authored
test: fix canaries-v2 (#5932)
* fix: make JumpStart private hub integ tests xdist-safe Under pytest-xdist (-n 120) each worker created its own private hub, exhausting the per-account hub limit (100) and triggering destructive cross-worker cleanup that deleted hubs other workers were actively using, causing "Hub ... does not exist" failures. The add_model_references fixture also swallowed all errors and did not wait for async reference propagation, causing "Hub content ... does not exist" failures. - Share a single hub across all xdist workers via filelock + a JSON state file with reference counting; only the last worker tears it down. - Make _cleanup_old_hubs non-destructive: only delete hubs older than STALE_HUB_AGE_HOURS and never the active run's hub. - Add add_model_references_to_hub helper that creates references idempotently (keyed by hub + model set) and polls until each reference is resolvable before tests run. * fix: isolate sagemaker_session for serve integ tests to prevent settings pollution ModelBuilder mutates session.settings._local_download_dir to a temporary /tmp/sagemaker/model-builder/<uuid> path. The serve integ tests passed the repo-wide session-scoped sagemaker_session fixture into ModelBuilder, so that mutation leaked across test modules. After the temp dir was cleaned up, the lingering setting broke unrelated tests sharing the same session, notably tests/integ/sagemaker/workflow/test_tuning_steps.py::test_tuning_multi_algos with "ValueError: Inputted directory ... does not exist". Override sagemaker_session in tests/integ/sagemaker/serve/conftest.py with a dedicated session (constructed identically to the parent fixture) so the ModelBuilder mutation stays contained within the serve package. * fix: tear down shared JumpStart hub after all xdist workers finish The previous reference-counted teardown in the session fixture finalizer was unsafe: pytest-xdist distributes tests dynamically, so a worker could finish its session (running finalizers) while other workers still had hub tests pending. Decrementing to zero there deleted the shared hub mid-run, causing "Hub ... does not exist" / "Hub content ... does not exist" failures in gated hub tests. Workers now only create-or-reuse the shared hub (never delete it). Teardown runs exactly once in pytest_sessionfinish on the controller process (no workerinput), which is guaranteed to run after all workers finish. Stale hub reclamation continues to be handled by the age-based _cleanup_old_hubs. * fix: stabilize Spark jar build and inference-component endpoint timeout in integ tests Two unrelated v2 integ-test failures, fixed together: - test_spark_processing.py::test_sagemaker_pyspark_v3 (Spark 3.x): build_jar ran javac/jar without checking exit codes, so a failed jar rebuild (which truncates the committed hello-spark-java.jar) was swallowed and surfaced later as a misleading "code ... wasn't found" error, especially under xdist where the fixture runs per worker. Run the build commands with explicit return-code checks and assert the jar exists afterward. - test_serve_model_builder_inference_component_happy.py:: test_model_builder_ic_sagemaker_endpoint: deploying a 7B JumpStart model as an inference component on ml.g5.24xlarge regularly needs more than the 15-minute standard endpoint timeout to reach InService (the failure was a deploy timeout, not a quota cap). Add a dedicated 30-minute timeout (SERVE_SAGEMAKER_IC_ENDPOINT_TIMEOUT) for this flow without changing the standard serve endpoint timeout. * https://us-west-2.console.aws.amazon.com/cloudwatch/home?region=us-west-2#logsV2:log-groups/log-group/$252Faws$252Fcodebuild$252Fsagemaker-python-sdk-ci-integ-tests/log-events/e558697a-488d-4eab-a4ad-2971d9a1081f * fix: stop deleting shared JumpStart hub mid-run; xfail flaky IC deploy test JumpStart hub: The shared hub was being deleted at session end on the controller, but hub tests deploy long-lived endpoints, so a straggler worker could still be running a hub test at ~100% when teardown deleted the hub, causing intermittent "Hub ... does not exist" failures (e.g. test_jumpstart_hub_gated_estimator_ with_eula). Stop deleting the hub during the run entirely: session-end teardown still cleans leaked endpoints/models/configs/artifacts but no longer deletes the hub, and stale hubs from prior runs are reclaimed proactively at setup via the age-based _cleanup_old_hubs (older than STALE_HUB_AGE_HOURS). Inference-component serve test: test_model_builder_ic_sagemaker_endpoint fails in the ModelBuilder IC deploy path: CreateEndpoint is followed by a DescribeEndpoint that intermittently reports the endpoint as not found. This is an SDK-level issue, not a test config problem, so xfail (non-strict) the test to unblock the canary while it is tracked separately. X-AI-Prompt: Stop mid-run hub deletion (rely on age-based reclamation) and xfail the flaky ModelBuilder inference-component deploy test X-AI-Tool: kiro-cli * test: speed up slow JumpStart estimator canary integ tests These canaries only need to exercise the train/deploy/predict flow, not produce a well-trained model, yet they dominated canary runtime (the estimator tests each ran ~100 min). Trim the training workload to bring the suite under one hour while keeping coverage intact. Bert estimator tests (full QNLI -> QNLI-tiny + epochs=1): - map the floating "*" version of huggingface-spc-bert-base-cased to the QNLI-tiny dataset instead of the full QNLI dataset (constants.py) - cap training to a single epoch (hyperparameters={"epochs": "1"}) for: - test_jumpstart_estimator - test_jumpstart_hub_estimator - test_jumpstart_hub_estimator_with_session Gated llama estimator tests (sec_amazon has no tiny variant, so cap steps via hyperparameters={"max_steps": "1"}): - test_gated_model_training_v1 - test_gated_model_training_v2 - test_jumpstart_hub_gated_estimator_with_eula X-AI-Prompt: Reduce JumpStart estimator canary test runtime by using the tiny training dataset and capping epochs/steps so the suite finishes under an hour X-AI-Tool: kiro-cli * test: mark JumpStart neuron gated training test as slow_test Excludes test_gated_model_training_v2_neuron from ci-integ-tests and canaries-v2, which both filter out `slow_test`. Trn1/Inf2 capacity makes this test prone to multi-hour stalls, and max_steps=1 cannot shrink the provisioning wait.
1 parent 07d5847 commit d0cbf41

11 files changed

Lines changed: 363 additions & 95 deletions

tests/integ/sagemaker/jumpstart/conftest.py

Lines changed: 146 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import json
1516
import os
17+
import pathlib
18+
from datetime import datetime, timedelta, timezone
19+
1620
import boto3
1721
import pytest
22+
from filelock import FileLock
1823
from botocore.config import Config
1924
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
2025
from sagemaker.jumpstart.hub.hub import Hub
@@ -39,19 +44,28 @@
3944
)
4045

4146

42-
def _setup():
47+
# Only delete leftover hubs from previous test runs that are older than this many
48+
# hours. This guards against deleting a hub that another concurrent test run (or
49+
# xdist worker) is actively using.
50+
STALE_HUB_AGE_HOURS = 3
51+
52+
53+
def _setup(test_suite_id=None, test_hub_name=None):
4354
print("Setting up...")
44-
test_suite_id = get_test_suite_id()
45-
test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}"
55+
test_suite_id = test_suite_id or get_test_suite_id()
56+
test_hub_name = test_hub_name or f"{HUB_NAME_PREFIX}{test_suite_id}"
4657
test_hub_description = "PySDK Integ Test Private Hub"
4758

4859
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: test_suite_id})
4960
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: test_hub_name})
5061

5162
# Create a private hub to use for the test session
52-
hub = Hub(
53-
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
54-
)
63+
hub = Hub(hub_name=test_hub_name, sagemaker_session=get_sm_session())
64+
65+
# Proactively reclaim stale hubs from prior runs so we don't accumulate
66+
# toward the per-account private hub limit. This only deletes hubs older
67+
# than STALE_HUB_AGE_HOURS and never the hub we are about to use.
68+
_cleanup_old_hubs(get_sm_session(), active_hub_name=test_hub_name)
5569

5670
# Check if hub already exists before creating
5771
try:
@@ -73,14 +87,14 @@ def _setup():
7387
raise
7488

7589

76-
def _teardown():
90+
def _teardown(test_suite_id=None, test_hub_name=None, delete_hub=False):
7791
print("Tearing down...")
7892

7993
test_cache_bucket = get_test_artifact_bucket()
8094

81-
test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
95+
test_suite_id = test_suite_id or os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
8296

83-
test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
97+
test_hub_name = test_hub_name or os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
8498

8599
boto3_session = boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)
86100

@@ -152,34 +166,49 @@ def _teardown():
152166
bucket = s3_resource.Bucket(test_cache_bucket)
153167
bucket.objects.filter(Prefix=test_suite_id + "/").delete()
154168

155-
# delete private hubs
156-
_delete_hubs(sagemaker_session, test_hub_name)
169+
# delete private hubs (only when explicitly requested). During an xdist run
170+
# we never delete the active hub, because a straggler worker may still be
171+
# running a hub test when another process reaches teardown; stale hubs from
172+
# prior runs are reclaimed by the age-based _cleanup_old_hubs instead.
173+
if delete_hub:
174+
_delete_hubs(sagemaker_session, test_hub_name)
175+
157176

177+
def _cleanup_old_hubs(sagemaker_session, active_hub_name=None):
178+
"""Clean up stale test hubs from previous runs to free up resources.
158179
159-
def _cleanup_old_hubs(sagemaker_session):
160-
"""Clean up old test hubs to free up resources."""
180+
Only deletes hubs that are clearly stale (older than ``STALE_HUB_AGE_HOURS``)
181+
so that hubs actively in use by the current test run or by concurrent xdist
182+
workers are never removed. The hub for the current run (``active_hub_name``)
183+
is always preserved.
184+
"""
161185
try:
186+
active_hub_name = active_hub_name or os.environ.get(ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME)
187+
cutoff = datetime.now(timezone.utc) - timedelta(hours=STALE_HUB_AGE_HOURS)
188+
162189
response = sagemaker_session.list_hubs()
163-
test_hubs = [
164-
hub
165-
for hub in response.get("HubSummaries", [])
166-
if hub["HubName"].startswith(HUB_NAME_PREFIX)
167-
]
168-
169-
# Sort by creation time and delete oldest hubs
170-
test_hubs.sort(key=lambda x: x.get("CreationTime", ""))
171-
172-
# Delete oldest hubs (keep only the most recent 10)
173-
hubs_to_delete = (
174-
test_hubs[:-10] if len(test_hubs) > 10 else test_hubs[: max(0, len(test_hubs) - 40)]
175-
)
190+
for hub in response.get("HubSummaries", []):
191+
hub_name = hub["HubName"]
192+
if not hub_name.startswith(HUB_NAME_PREFIX):
193+
continue
194+
if hub_name == active_hub_name:
195+
continue
196+
197+
creation_time = hub.get("CreationTime")
198+
# Only delete hubs we can confirm are older than the cutoff. If the
199+
# creation time is unavailable, err on the side of keeping the hub.
200+
if creation_time is None:
201+
continue
202+
if creation_time.tzinfo is None:
203+
creation_time = creation_time.replace(tzinfo=timezone.utc)
204+
if creation_time >= cutoff:
205+
continue
176206

177-
for hub in hubs_to_delete:
178207
try:
179-
print(f"Deleting old hub: {hub['HubName']}")
180-
_delete_hubs(sagemaker_session, hub["HubName"])
208+
print(f"Deleting stale hub: {hub_name}")
209+
_delete_hubs(sagemaker_session, hub_name)
181210
except Exception as e:
182-
print(f"Failed to delete hub {hub['HubName']}: {e}")
211+
print(f"Failed to delete hub {hub_name}: {e}")
183212
except Exception as e:
184213
print(f"Failed to cleanup old hubs: {e}")
185214

@@ -210,8 +239,92 @@ def _delete_hub_contents(sagemaker_session, hub_name, model):
210239
)
211240

212241

242+
def _hub_state_root(config):
243+
"""Return the run-level tmp dir shared by the xdist controller and workers.
244+
245+
The controller's basetemp is the run root (e.g. ``.../pytest-N``) while each
246+
worker's basetemp is a ``popen-gw*`` subdir of it. Normalizing to the run
247+
root gives every process the same location for the shared state file.
248+
249+
Works across pytest versions: prefers the ``TempPathFactory`` attached as
250+
``config._tmp_path_factory`` and falls back to the legacy ``_tmpdirhandler``.
251+
"""
252+
factory = getattr(config, "_tmp_path_factory", None)
253+
if factory is not None:
254+
basetemp = pathlib.Path(str(factory.getbasetemp()))
255+
else:
256+
basetemp = pathlib.Path(str(config._tmpdirhandler.getbasetemp()))
257+
258+
if basetemp.name.startswith("popen-gw"):
259+
return basetemp.parent
260+
return basetemp
261+
262+
213263
@pytest.fixture(scope="session", autouse=True)
214264
def setup(request):
215-
_setup()
216-
217-
request.addfinalizer(_teardown)
265+
"""Ensure a single shared private hub exists for the whole test run.
266+
267+
Under pytest-xdist every worker is a separate process, so a naive
268+
``scope="session"`` fixture would create one hub per worker. With high
269+
parallelism (e.g. ``-n 120``) that quickly exhausts the per-account private
270+
hub limit (100). All workers therefore coordinate through a lock file and a
271+
shared JSON state file: the first worker creates the hub, the rest reuse it.
272+
273+
The hub is intentionally NOT deleted at the end of the run. xdist
274+
distributes tests dynamically and hub tests deploy long-lived endpoints, so
275+
a straggler worker can still be running a hub test (at ~100%) while another
276+
process reaches teardown. Deleting the hub there pulls it out from under the
277+
straggler ("Hub ... does not exist" failures). Instead, leaked endpoints and
278+
artifacts are cleaned at run end, and the hub itself is reclaimed on a later
279+
run by the age-based ``_cleanup_old_hubs`` (older than STALE_HUB_AGE_HOURS).
280+
"""
281+
root_tmp_dir = _hub_state_root(request.config)
282+
state_file = root_tmp_dir / "jumpstart_hub_state.json"
283+
lock_file = root_tmp_dir / "jumpstart_hub_state.json.lock"
284+
285+
with FileLock(str(lock_file)):
286+
if state_file.is_file():
287+
state = json.loads(state_file.read_text())
288+
else:
289+
test_suite_id = get_test_suite_id()
290+
test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}"
291+
_setup(test_suite_id=test_suite_id, test_hub_name=test_hub_name)
292+
state = {
293+
"test_suite_id": test_suite_id,
294+
"test_hub_name": test_hub_name,
295+
}
296+
state_file.write_text(json.dumps(state))
297+
298+
# Ensure this worker's environment points at the shared hub.
299+
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: state["test_suite_id"]})
300+
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: state["test_hub_name"]})
301+
302+
303+
def pytest_sessionfinish(session, exitstatus):
304+
"""Clean up leaked test resources once, after all xdist workers finish.
305+
306+
Runs only on the controller (xdist workers carry a ``workerinput`` attribute
307+
on their config; a non-xdist run has none). Deletes endpoints/models/configs
308+
and S3 artifacts tagged for this run, but deliberately does NOT delete the
309+
shared hub (see ``setup``); stale hubs are reclaimed by ``_cleanup_old_hubs``
310+
on a subsequent run.
311+
"""
312+
if hasattr(session.config, "workerinput"):
313+
return # xdist worker: the controller handles cleanup.
314+
315+
root_tmp_dir = _hub_state_root(session.config)
316+
state_file = root_tmp_dir / "jumpstart_hub_state.json"
317+
lock_file = root_tmp_dir / "jumpstart_hub_state.json.lock"
318+
319+
with FileLock(str(lock_file)):
320+
if not state_file.is_file():
321+
return
322+
state = json.loads(state_file.read_text())
323+
try:
324+
_teardown(
325+
test_suite_id=state["test_suite_id"],
326+
test_hub_name=state["test_hub_name"],
327+
delete_hub=False,
328+
)
329+
finally:
330+
state_file.unlink()

tests/integ/sagemaker/jumpstart/constants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4747
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
4848
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
4949
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
50-
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"),
50+
# Use the tiny dataset for the floating "*" version too: these are canary
51+
# tests that only need to exercise the train/deploy flow, not produce a
52+
# well-trained model. The full QNLI dataset made fit() dramatically slower.
53+
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
5154
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
5255
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
5356
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def test_jumpstart_estimator(setup):
6161
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
6262
max_run=259200, # avoid exceeding resource limits
6363
instance_type="ml.g4dn.xlarge",
64+
# Canary only needs to exercise the train/deploy flow, so cap training
65+
# to a single epoch to keep fit() fast.
66+
hyperparameters={"epochs": "1"},
6467
)
6568

6669
# uses ml.g4dn.xlarge instance
@@ -111,6 +114,9 @@ def test_gated_model_training_v1(setup):
111114
environment={"accept_eula": "true"},
112115
max_run=259200, # avoid exceeding resource limits
113116
tolerate_vulnerable_model=True,
117+
# Canary only verifies the train/deploy flow, so cap training to a
118+
# single step to keep fit() fast (sec_amazon has no tiny variant).
119+
hyperparameters={"max_steps": "1"},
114120
)
115121

116122
# uses ml.g5.12xlarge instance
@@ -153,6 +159,9 @@ def test_gated_model_training_v2(setup):
153159
environment={"accept_eula": "true"},
154160
max_run=259200, # avoid exceeding resource limits
155161
tolerate_vulnerable_model=True, # tolerate old version of model
162+
# Canary only verifies the train/deploy flow, so cap training to a
163+
# single step to keep fit() fast (sec_amazon has no tiny variant).
164+
hyperparameters={"max_steps": "1"},
156165
)
157166

158167
# uses ml.g5.12xlarge instance
@@ -190,6 +199,7 @@ def test_gated_model_training_v2(setup):
190199

191200

192201
@x_fail_if_ice
202+
@pytest.mark.slow_test
193203
@pytest.mark.skipif(
194204
tests.integ.test_region() not in TRN2_SUPPORTED_REGIONS,
195205
reason=f"TRN2 instances unavailable in {tests.integ.test_region()}.",

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
download_inference_assets,
3535
get_sm_session,
3636
get_tabular_data,
37+
x_fail_if_ice,
3738
)
3839

3940
INF2_SUPPORTED_REGIONS = {
@@ -192,6 +193,7 @@ def test_jumpstart_gated_model(setup):
192193
assert response is not None
193194

194195

196+
@x_fail_if_ice
195197
def test_jumpstart_gated_model_inference_component_enabled(setup):
196198

197199
model_id = "meta-textgeneration-llama-2-7b"

tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import pytest
1919
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
20-
from sagemaker.jumpstart.hub.hub import Hub
2120

2221
from sagemaker.jumpstart.estimator import JumpStartEstimator
2322
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
@@ -28,10 +27,9 @@
2827
JUMPSTART_TAG,
2928
)
3029
from tests.integ.sagemaker.jumpstart.utils import (
31-
get_public_hub_model_arn,
3230
get_sm_session,
33-
with_exponential_backoff,
3431
get_training_dataset_for_model_and_version,
32+
add_model_references_to_hub,
3533
)
3634

3735
MAX_INIT_TIME_SECONDS = 5
@@ -43,23 +41,13 @@
4341
}
4442

4543

46-
@with_exponential_backoff()
47-
def create_model_reference(hub_instance, model_arn):
48-
try:
49-
hub_instance.create_model_reference(model_arn=model_arn)
50-
except Exception:
51-
pass
52-
53-
5444
@pytest.fixture(scope="session")
5545
def add_model_references():
56-
# Create Model References to test in Hub
57-
hub_instance = Hub(
58-
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
46+
# Create Model References to test in Hub (idempotent + waits for readiness)
47+
add_model_references_to_hub(
48+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
49+
model_ids=TEST_MODEL_IDS,
5950
)
60-
for model in TEST_MODEL_IDS:
61-
model_arn = get_public_hub_model_arn(hub_instance, model)
62-
create_model_reference(hub_instance, model_arn)
6351

6452

6553
def test_jumpstart_hub_estimator(setup, add_model_references):
@@ -70,6 +58,9 @@ def test_jumpstart_hub_estimator(setup, add_model_references):
7058
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
7159
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
7260
instance_type="ml.g4dn.xlarge",
61+
# Canary only needs to exercise the train/deploy flow, so cap training
62+
# to a single epoch to keep fit() fast.
63+
hyperparameters={"epochs": "1"},
7364
)
7465

7566
estimator.fit(
@@ -110,6 +101,9 @@ def test_jumpstart_hub_estimator_with_session(setup, add_model_references):
110101
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
111102
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
112103
instance_type="ml.g4dn.xlarge",
104+
# Canary only needs to exercise the train/deploy flow, so cap training
105+
# to a single epoch to keep fit() fast.
106+
hyperparameters={"epochs": "1"},
113107
)
114108

115109
estimator.fit(
@@ -149,6 +143,9 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
149143
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
150144
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
151145
instance_type="ml.g5.2xlarge",
146+
# Canary only verifies the train/deploy flow, so cap training to a
147+
# single step to keep fit() fast (sec_amazon has no tiny variant).
148+
hyperparameters={"max_steps": "1"},
152149
)
153150

154151
estimator.fit(

0 commit comments

Comments
 (0)