Skip to content

Commit 81158b3

Browse files
committed
test: fail fast on Trainium capacity queueing in neuron slow test
test_gated_model_training_v2_neuron submits a training job on ml.trn1.32xlarge, which can sit in "Pending - waiting for capacity" for hours without ever failing, hanging the slow-tests build until it times out and masking other tests' results. Add a fit_estimator_with_capacity_xfail() helper that submits the job non-blocking, polls its secondary status, and, if it stays queued for capacity beyond a timeout (default 30m), stops the job and raises a RuntimeError whose message contains "CapacityError", so the existing x_fail_if_ice decorator marks the test xfail instead of hanging. Capacity is a transient region-level condition, not an SDK defect, and quota is sufficient (trn1.32xlarge training quota=20, test needs 1).

X-AI-Prompt: Implement non-blocking fit with capacity-queue timeout that triggers xfail via x_fail_if_ice for the neuron slow test
X-AI-Tool: kiro-cli
1 parent 764fb2e commit 81158b3

2 files changed

Lines changed: 80 additions & 2 deletions

File tree

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
get_sm_session,
2929
get_training_dataset_for_model_and_version,
3030
x_fail_if_ice,
31+
fit_estimator_with_capacity_xfail,
3132
)
3233

3334
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
@@ -221,11 +222,12 @@ def test_gated_model_training_v2_neuron(setup):
221222
)
222223

223224
# uses ml.trn1.32xlarge instance
224-
estimator.fit(
225+
fit_estimator_with_capacity_xfail(
226+
estimator,
225227
{
226228
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
227229
f"{get_training_dataset_for_model_and_version(model_id, '*')}",
228-
}
230+
},
229231
)
230232

231233
# uses ml.inf2.xlarge instance

tests/integ/sagemaker/jumpstart/utils.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,82 @@ def wrapper(*args, **kwargs):
9292
return wrapper
9393

9494

95+
# Default time (in seconds) we allow a training job to sit in
# "Pending - waiting for capacity" before treating it as a capacity shortage.
# Scarce accelerators (e.g. Trainium ml.trn1.32xlarge) can queue for hours
# without ever failing, which would otherwise exhaust the build timeout.
DEFAULT_CAPACITY_WAIT_TIMEOUT_SECONDS = 30 * 60
# Default pause (in seconds) between two consecutive status polls.
_CAPACITY_POLL_INTERVAL_SECONDS = 60

# Secondary statuses that are treated as "capacity not granted yet".
_PRE_CAPACITY_SECONDARY_STATUSES = ("Starting", "Pending")
# Primary statuses after which polling stops.
_TERMINAL_TRAINING_JOB_STATUSES = ("Completed", "Failed", "Stopped")


def fit_estimator_with_capacity_xfail(
    estimator,
    inputs,
    capacity_wait_timeout_seconds: int = DEFAULT_CAPACITY_WAIT_TIMEOUT_SECONDS,
    poll_interval_seconds: float = _CAPACITY_POLL_INTERVAL_SECONDS,
):
    """Submit a training job and wait for it, but fail fast on capacity queueing.

    Unlike ``estimator.fit()`` (which blocks until the job reaches a terminal
    state), this submits the job non-blocking and then polls its secondary
    status. Insufficient capacity for scarce instances (e.g. Trainium) does not
    surface as an exception; SageMaker simply parks the job in
    ``Pending`` / "waiting for capacity" indefinitely. If the job stays in that
    state longer than ``capacity_wait_timeout_seconds`` before training starts,
    we raise a ``RuntimeError`` whose message contains ``"CapacityError"`` so the
    ``x_fail_if_ice`` decorator marks the test as xfail instead of letting it hang
    until the build times out.

    Args:
        estimator: A SageMaker estimator (e.g. ``JumpStartEstimator``).
        inputs: The training inputs passed through to ``estimator.fit()``.
        capacity_wait_timeout_seconds: Max time to allow the job to remain in
            "waiting for capacity" before declaring a capacity shortage.
        poll_interval_seconds: Time to sleep between two status polls.

    Returns:
        The ``describe_training_job`` response observed when the job reached
        the ``Completed`` status.

    Raises:
        RuntimeError: If the job ends as ``Failed`` or ``Stopped`` (the message
            carries the failure reason), or if it never left the pre-capacity
            phases within ``capacity_wait_timeout_seconds`` (the message starts
            with ``"CapacityError"``).
    """
    # Function-scope imports keep the helper self-contained regardless of what
    # the module header already imports.
    import logging
    import time

    estimator.fit(inputs, wait=False)

    training_job_name = estimator.latest_training_job.name
    sagemaker_client = estimator.sagemaker_session.sagemaker_client

    # Use the monotonic clock so wall-clock adjustments cannot shorten or
    # extend the capacity window.
    capacity_wait_deadline = time.monotonic() + capacity_wait_timeout_seconds
    training_started = False

    while True:
        desc = sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
        status = desc["TrainingJobStatus"]
        secondary_status = desc.get("SecondaryStatus", "")
        secondary_message = desc.get("SecondaryStatusMessage", "")

        # Once the job leaves the pre-capacity phases, capacity has been granted
        # and normal training-time limits (MaxRuntimeInSeconds) take over.
        # A missing/empty secondary status proves nothing, so it must not
        # switch the capacity timeout off.
        if secondary_status and secondary_status not in _PRE_CAPACITY_SECONDARY_STATUSES:
            training_started = True

        if status in _TERMINAL_TRAINING_JOB_STATUSES:
            if status != "Completed":
                raise RuntimeError(
                    f"Training job {training_job_name} ended with status {status}: "
                    f"{desc.get('FailureReason', secondary_message)}"
                )
            return desc

        if not training_started and time.monotonic() > capacity_wait_deadline:
            # Stop the queued job so it does not keep holding the request, then
            # raise a CapacityError so x_fail_if_ice converts this to an xfail.
            # Stopping is best-effort cleanup: whatever goes wrong here must not
            # replace the capacity error below, so it is logged and ignored.
            try:
                sagemaker_client.stop_training_job(TrainingJobName=training_job_name)
            except Exception as stop_error:  # pylint: disable=broad-except
                logging.getLogger(__name__).warning(
                    "Could not stop queued training job %s: %s",
                    training_job_name,
                    stop_error,
                )
            raise RuntimeError(
                "CapacityError: training job "
                f"{training_job_name} stayed in '{secondary_status}' "
                f"({secondary_message!r}) for over "
                f"{capacity_wait_timeout_seconds}s without acquiring "
                f"{estimator.instance_type} capacity."
            )

        time.sleep(poll_interval_seconds)
169+
170+
95171
def download_inference_assets():
96172

97173
if not os.path.exists(TMP_DIRECTORY_PATH):

0 commit comments

Comments
 (0)