Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/spot_instance_analysis.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ With ~7% interruption rate and 15-min checkpoints, expect ~10-15% overhead from

## Implementation TODO

- [ ] Add spot instance support to `aws_vm.py` (`create_vm` with `InstanceMarketOptions`)
- [x] Add spot instance support to `aws_vm.py` (`create_vm` with `InstanceMarketOptions`)
- [ ] Add S3 checkpoint upload to training loop
- [ ] Add termination handler (metadata polling + checkpoint trigger)
- [ ] Add g6 instance types to `GPU_INSTANCE_TYPE_FALLBACKS` in `aws_vm.py`
Expand Down
29 changes: 25 additions & 4 deletions openadapt_evals/infrastructure/aws_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def create_vm(
admin_username: str = "ubuntu",
image_id: str | None = None,
gpu: bool = False,
spot: bool = False,
) -> dict[str, Any]:
"""Create an EC2 instance.

Expand All @@ -382,6 +383,10 @@ def create_vm(
image_id: AMI ID. If None, auto-selected based on gpu flag.
gpu: If True, use the Deep Learning AMI with pre-installed
NVIDIA drivers and CUDA. Required for GPU training.
spot: If True, request a spot instance for cost savings.
The instance will use one-time spot pricing and terminate
on interruption. See docs/spot_instance_analysis.md for
pricing details and interruption risk.

Returns:
Dict with at least "publicIpAddress" key.
Expand Down Expand Up @@ -409,8 +414,8 @@ def create_vm(
else:
ami_id = self._find_latest_ubuntu_ami(region)

# Launch instance
instances = ec2_resource.create_instances(
# Build create_instances kwargs
create_kwargs: dict[str, Any] = dict(
ImageId=ami_id,
InstanceType=size,
KeyName=infra["key_name"],
Expand Down Expand Up @@ -445,6 +450,21 @@ def create_vm(
],
)

if spot:
logger.info(
"Requesting spot instance for %s (%s)", name, size
)
create_kwargs["InstanceMarketOptions"] = {
"MarketType": "spot",
"SpotOptions": {
"SpotInstanceType": "one-time",
"InstanceInterruptionBehavior": "terminate",
},
}

# Launch instance
instances = ec2_resource.create_instances(**create_kwargs)

instance = instances[0]
logger.info(f"Instance {name} ({instance.id}) launching...")

Expand All @@ -469,8 +489,9 @@ def create_vm(
)
public_ip = eip["PublicIp"]

logger.info(f"Instance {name} running at {public_ip}")
return {"publicIpAddress": public_ip, "name": name}
spot_label = " (spot)" if spot else ""
logger.info(f"Instance {name}{spot_label} running at {public_ip}")
return {"publicIpAddress": public_ip, "name": name, "spot": spot}

except Exception as e:
raise RuntimeError(f"EC2 instance creation failed: {e}") from e
Expand Down
31 changes: 25 additions & 6 deletions scripts/train_verl_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,16 @@ def _scp_upload(ip: str, local_path: Path, remote_path: str, username: str = "ub
raise RuntimeError(f"SCP failed: {result.stderr}")


def provision_gpu_vm(cloud: str, dry_run: bool = False) -> tuple[str, str, str]:
def provision_gpu_vm(
cloud: str, dry_run: bool = False, spot: bool = False,
) -> tuple[str, str, str]:
"""Provision a GPU VM and return (ip, size, region).

Args:
cloud: Cloud provider ("azure" or "aws").
dry_run: If True, show what would happen without creating anything.
spot: If True, request a spot instance for cost savings (AWS only).

Returns:
Tuple of (public_ip, vm_size, region).
"""
Expand All @@ -98,14 +105,20 @@ def provision_gpu_vm(cloud: str, dry_run: bool = False) -> tuple[str, str, str]:

logger.info("Finding available GPU VM size...")
vm_size, region, cost = vm.find_available_size_and_region(gpu=True)
logger.info("Selected: %s ($%.2f/hr) in %s", vm_size, cost, region)
spot_label = " (spot)" if spot else ""
logger.info("Selected: %s ($%.2f/hr)%s in %s", vm_size, cost, spot_label, region)

if dry_run:
logger.info("[DRY RUN] Would create %s in %s", GPU_VM_NAME, region)
logger.info("[DRY RUN] Would create %s%s in %s", GPU_VM_NAME, spot_label, region)
return ("DRY_RUN_IP", vm_size, region)

logger.info("Creating GPU VM '%s'...", GPU_VM_NAME)
info = vm.create_vm(name=GPU_VM_NAME, region=region, size=vm_size)
logger.info("Creating GPU VM '%s'%s...", GPU_VM_NAME, spot_label)
create_kwargs: dict = dict(name=GPU_VM_NAME, region=region, size=vm_size, gpu=True)
if spot and cloud == "aws":
create_kwargs["spot"] = True
elif spot and cloud != "aws":
logger.warning("--spot is only supported on AWS; ignoring for %s", cloud)
info = vm.create_vm(**create_kwargs)
ip = info.get("publicIpAddress") or vm.get_vm_ip(GPU_VM_NAME)

if not ip:
Expand Down Expand Up @@ -476,6 +489,10 @@ def main():
"--epochs", type=int, default=100,
help="Training epochs (default: 100)",
)
parser.add_argument(
"--spot", action="store_true",
help="Use spot instances for cost savings (AWS only, ~50%% cheaper)",
)
parser.add_argument(
"--setup-only", action="store_true",
help="Only provision and setup, don't start training",
Expand Down Expand Up @@ -511,7 +528,9 @@ def main():
ip = args.gpu_ip
logger.info("Using existing GPU VM: %s", ip)
else:
ip, vm_size, region = provision_gpu_vm(args.cloud, dry_run=args.dry_run)
ip, vm_size, region = provision_gpu_vm(
args.cloud, dry_run=args.dry_run, spot=args.spot,
)
if args.dry_run:
logger.info("[DRY RUN] Would setup and train on %s", vm_size)
return
Expand Down