Skip to content

Commit 93fc1c6

Browse files
abrichrclaude
andcommitted
feat: add spot instance support to AWS VM creation
Add spot=True parameter to AWSVMManager.create_vm() which sets InstanceMarketOptions for one-time spot pricing with terminate-on- interruption behavior. Wire --spot flag through train_verl_e2e.py CLI. Saves ~50% on GPU training costs (e.g. g5.xlarge $0.43/hr vs $1.006/hr). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent dea0b27 commit 93fc1c6

3 files changed

Lines changed: 51 additions & 11 deletions

File tree

docs/spot_instance_analysis.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ With ~7% interruption rate and 15-min checkpoints, expect ~10-15% overhead from
6767

6868
## Implementation TODO
6969

70-
- [ ] Add spot instance support to `aws_vm.py` (`create_vm` with `InstanceMarketOptions`)
70+
- [x] Add spot instance support to `aws_vm.py` (`create_vm` with `InstanceMarketOptions`)
7171
- [ ] Add S3 checkpoint upload to training loop
7272
- [ ] Add termination handler (metadata polling + checkpoint trigger)
7373
- [ ] Add g6 instance types to `GPU_INSTANCE_TYPE_FALLBACKS` in `aws_vm.py`

openadapt_evals/infrastructure/aws_vm.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ def create_vm(
370370
admin_username: str = "ubuntu",
371371
image_id: str | None = None,
372372
gpu: bool = False,
373+
spot: bool = False,
373374
) -> dict[str, Any]:
374375
"""Create an EC2 instance.
375376
@@ -382,6 +383,10 @@ def create_vm(
382383
image_id: AMI ID. If None, auto-selected based on gpu flag.
383384
gpu: If True, use the Deep Learning AMI with pre-installed
384385
NVIDIA drivers and CUDA. Required for GPU training.
386+
spot: If True, request a spot instance for cost savings.
387+
The instance will use one-time spot pricing and terminate
388+
on interruption. See docs/spot_instance_analysis.md for
389+
pricing details and interruption risk.
385390
386391
Returns:
387392
Dict with at least "publicIpAddress" key.
@@ -409,8 +414,8 @@ def create_vm(
409414
else:
410415
ami_id = self._find_latest_ubuntu_ami(region)
411416

412-
# Launch instance
413-
instances = ec2_resource.create_instances(
417+
# Build create_instances kwargs
418+
create_kwargs: dict[str, Any] = dict(
414419
ImageId=ami_id,
415420
InstanceType=size,
416421
KeyName=infra["key_name"],
@@ -445,6 +450,21 @@ def create_vm(
445450
],
446451
)
447452

453+
if spot:
454+
logger.info(
455+
"Requesting spot instance for %s (%s)", name, size
456+
)
457+
create_kwargs["InstanceMarketOptions"] = {
458+
"MarketType": "spot",
459+
"SpotOptions": {
460+
"SpotInstanceType": "one-time",
461+
"InstanceInterruptionBehavior": "terminate",
462+
},
463+
}
464+
465+
# Launch instance
466+
instances = ec2_resource.create_instances(**create_kwargs)
467+
448468
instance = instances[0]
449469
logger.info(f"Instance {name} ({instance.id}) launching...")
450470

@@ -469,8 +489,9 @@ def create_vm(
469489
)
470490
public_ip = eip["PublicIp"]
471491

472-
logger.info(f"Instance {name} running at {public_ip}")
473-
return {"publicIpAddress": public_ip, "name": name}
492+
spot_label = " (spot)" if spot else ""
493+
logger.info(f"Instance {name}{spot_label} running at {public_ip}")
494+
return {"publicIpAddress": public_ip, "name": name, "spot": spot}
474495

475496
except Exception as e:
476497
raise RuntimeError(f"EC2 instance creation failed: {e}") from e

scripts/train_verl_e2e.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,16 @@ def _scp_upload(ip: str, local_path: Path, remote_path: str, username: str = "ub
8787
raise RuntimeError(f"SCP failed: {result.stderr}")
8888

8989

90-
def provision_gpu_vm(cloud: str, dry_run: bool = False) -> tuple[str, str, str]:
90+
def provision_gpu_vm(
91+
cloud: str, dry_run: bool = False, spot: bool = False,
92+
) -> tuple[str, str, str]:
9193
"""Provision a GPU VM and return (ip, size, region).
9294
95+
Args:
96+
cloud: Cloud provider ("azure" or "aws").
97+
dry_run: If True, show what would happen without creating anything.
98+
spot: If True, request a spot instance for cost savings (AWS only).
99+
93100
Returns:
94101
Tuple of (public_ip, vm_size, region).
95102
"""
@@ -98,14 +105,20 @@ def provision_gpu_vm(cloud: str, dry_run: bool = False) -> tuple[str, str, str]:
98105

99106
logger.info("Finding available GPU VM size...")
100107
vm_size, region, cost = vm.find_available_size_and_region(gpu=True)
101-
logger.info("Selected: %s ($%.2f/hr) in %s", vm_size, cost, region)
108+
spot_label = " (spot)" if spot else ""
109+
logger.info("Selected: %s ($%.2f/hr)%s in %s", vm_size, cost, spot_label, region)
102110

103111
if dry_run:
104-
logger.info("[DRY RUN] Would create %s in %s", GPU_VM_NAME, region)
112+
logger.info("[DRY RUN] Would create %s%s in %s", GPU_VM_NAME, spot_label, region)
105113
return ("DRY_RUN_IP", vm_size, region)
106114

107-
logger.info("Creating GPU VM '%s'...", GPU_VM_NAME)
108-
info = vm.create_vm(name=GPU_VM_NAME, region=region, size=vm_size)
115+
logger.info("Creating GPU VM '%s'%s...", GPU_VM_NAME, spot_label)
116+
create_kwargs: dict = dict(name=GPU_VM_NAME, region=region, size=vm_size, gpu=True)
117+
if spot and cloud == "aws":
118+
create_kwargs["spot"] = True
119+
elif spot and cloud != "aws":
120+
logger.warning("--spot is only supported on AWS; ignoring for %s", cloud)
121+
info = vm.create_vm(**create_kwargs)
109122
ip = info.get("publicIpAddress") or vm.get_vm_ip(GPU_VM_NAME)
110123

111124
if not ip:
@@ -476,6 +489,10 @@ def main():
476489
"--epochs", type=int, default=100,
477490
help="Training epochs (default: 100)",
478491
)
492+
parser.add_argument(
493+
"--spot", action="store_true",
494+
help="Use spot instances for cost savings (AWS only, ~50%% cheaper)",
495+
)
479496
parser.add_argument(
480497
"--setup-only", action="store_true",
481498
help="Only provision and setup, don't start training",
@@ -511,7 +528,9 @@ def main():
511528
ip = args.gpu_ip
512529
logger.info("Using existing GPU VM: %s", ip)
513530
else:
514-
ip, vm_size, region = provision_gpu_vm(args.cloud, dry_run=args.dry_run)
531+
ip, vm_size, region = provision_gpu_vm(
532+
args.cloud, dry_run=args.dry_run, spot=args.spot,
533+
)
515534
if args.dry_run:
516535
logger.info("[DRY RUN] Would setup and train on %s", vm_size)
517536
return

0 commit comments

Comments
 (0)