Skip to content

Commit 7a27d60

Browse files
abrichr and claude authored
feat: add multi-cloud VM support with AWS backend and VMProvider protocol (#66)
* feat: add multi-cloud VM support with AWS backend and VMProvider protocol
  - Create VMProvider Protocol (typing.Protocol) for cloud-agnostic VM management
  - Create AWSVMManager with boto3 for EC2 lifecycle (create, delete, start, stop)
  - Add resource_scope/ssh_username properties to AzureVMManager
  - Add list_pool_resources/cleanup_pool_resources to AzureVMManager
  - Parameterize pool.py SSH calls and scripts with username/home_dir
  - Add --cloud flag (azure|aws) to all pool CLI commands
  - Add cloud_provider/aws_region to config.py settings
  - Add boto3 optional dependency (openadapt-evals[aws])
  - Update tests for WAA_START_SCRIPT_TEMPLATE rename
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* fix: address review findings in AWS VM backend
  - Fix DOCKER_SETUP_SCRIPT_WITH_ACR daemon.json double-brace corruption that produced invalid JSON ({{"data-root"...}}) breaking Docker start
  - Use .metal instance types for AWS (KVM/nested virt required for QEMU)
  - Fix region mismatch: update self.region and invalidate cached clients when create_vm uses a different region than the manager default
  - Fix hardcoded "azureuser" in pool-wait diagnostic message
  - Set AWSVMManager = None on ImportError so `import *` doesn't raise
  - Only delete pool registry on successful cleanup (prevents orphaned cloud resources when deletion fails)
  - Remove unused `time` import from aws_vm.py
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* fix: address second review findings
  - Fix pool-vnc/pool-logs/pool-exec hardcoded azureuser: read ssh_username from pool registry with backward-compatible default
  - Store ssh_username in VMPool dataclass and persist to registry on create
  - Move set_auto_shutdown after SSH is available (was racing with boot)
  - Fix cleanup_pool_resources: handle raw instance IDs and allocation IDs for resources without Name tags (prevents orphaned resources)
  - Narrow key pair exception handling: re-raise unless InvalidKeyPair.NotFound
  - Add TODO for restricting SSH security group to user's IP
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* fix: restore ssh_username on registry load, fix EIP disassociate API
  - Add ssh_username to VMPoolRegistry.load() so it persists across process restarts (was silently reverting to "azureuser" default)
  - Fix disassociate_address for raw allocation IDs: look up AssociationId via describe_addresses first (disassociate_address does not accept AllocationId parameter)
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e8b2e04 commit 7a27d60

11 files changed

Lines changed: 1425 additions & 198 deletions

File tree

openadapt_evals/benchmarks/vm_cli.py

Lines changed: 71 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,44 @@ def _get_resource_group() -> str:
7676

7777

7878
RESOURCE_GROUP = _get_resource_group()
79+
80+
81+
def _get_default_cloud() -> str:
82+
"""Get default cloud provider from config."""
83+
try:
84+
from openadapt_evals.config import settings
85+
86+
return settings.cloud_provider
87+
except Exception:
88+
return "azure"
89+
90+
91+
def _get_pool_ssh_username(pool: dict) -> str:
92+
"""Get SSH username from pool registry, defaulting to azureuser for backward compat."""
93+
return pool.get("ssh_username", "azureuser")
94+
95+
96+
def _create_vm_manager(cloud: str | None = None, resource_group: str | None = None):
97+
"""Factory to create the appropriate VM manager based on cloud provider.
98+
99+
Args:
100+
cloud: Cloud provider ("azure" or "aws"). If None, uses config default.
101+
resource_group: Azure resource group (ignored for AWS).
102+
103+
Returns:
104+
VMProvider instance (AzureVMManager or AWSVMManager).
105+
"""
106+
cloud = cloud or _get_default_cloud()
107+
if cloud == "aws":
108+
from openadapt_evals.infrastructure.aws_vm import AWSVMManager
109+
110+
return AWSVMManager()
111+
else:
112+
from openadapt_evals.infrastructure.azure_vm import AzureVMManager
113+
114+
return AzureVMManager(resource_group=resource_group or RESOURCE_GROUP)
115+
116+
79117
# Custom WAA image built from waa_deploy/Dockerfile
80118
# Uses dockurr/windows:latest as base (with proper ISO download) + WAA components
81119
DOCKER_IMAGE = "waa-auto:latest"
@@ -518,11 +556,10 @@ def cmd_pool_status(args):
518556
"""Show status of all VMs in the current pool."""
519557
init_logging()
520558

521-
from openadapt_evals.infrastructure.azure_vm import AzureVMManager
522559
from openadapt_evals.infrastructure.pool import PoolManager
523560
from openadapt_evals.infrastructure.vm_monitor import VMMonitor, VMConfig
524561

525-
vm_manager = AzureVMManager(resource_group=RESOURCE_GROUP)
562+
vm_manager = _create_vm_manager(getattr(args, "cloud", None))
526563
manager = PoolManager(vm_manager=vm_manager, log_fn=log)
527564
pool = manager.status()
528565

@@ -592,10 +629,9 @@ def cmd_delete_pool(args):
592629
init_logging()
593630
from concurrent.futures import ThreadPoolExecutor, as_completed
594631

595-
from openadapt_evals.infrastructure.azure_vm import AzureVMManager
596632
from openadapt_evals.infrastructure.pool import PoolManager
597633

598-
vm_manager = AzureVMManager(resource_group=RESOURCE_GROUP)
634+
vm_manager = _create_vm_manager(getattr(args, "cloud", None))
599635
manager = PoolManager(vm_manager=vm_manager, log_fn=log)
600636
pool = manager.status()
601637

@@ -642,15 +678,14 @@ def cmd_pool_create(args):
642678
Uses ThreadPoolExecutor for concurrent VM creation.
643679
"""
644680
init_logging()
645-
from openadapt_evals.infrastructure.azure_vm import AzureVMManager
646681
from openadapt_evals.infrastructure.pool import PoolManager
647682

648683
num_workers = getattr(args, "workers", 3)
649684
auto_shutdown_hours = getattr(args, "auto_shutdown_hours", 4)
650685
use_acr = getattr(args, "use_acr", False)
651686
image_id = getattr(args, "image", None)
652687

653-
vm_manager = AzureVMManager(resource_group=RESOURCE_GROUP)
688+
vm_manager = _create_vm_manager(getattr(args, "cloud", None))
654689
manager = PoolManager(vm_manager=vm_manager, log_fn=log)
655690

656691
try:
@@ -673,13 +708,12 @@ def cmd_pool_wait(args):
673708
and the WAA server to respond.
674709
"""
675710
init_logging()
676-
from openadapt_evals.infrastructure.azure_vm import AzureVMManager
677711
from openadapt_evals.infrastructure.pool import PoolManager
678712

679713
timeout_minutes = getattr(args, "timeout", 30)
680714
no_start = getattr(args, "no_start", False)
681715

682-
vm_manager = AzureVMManager(resource_group=RESOURCE_GROUP)
716+
vm_manager = _create_vm_manager(getattr(args, "cloud", None))
683717
manager = PoolManager(vm_manager=vm_manager, log_fn=log)
684718

685719
try:
@@ -700,15 +734,14 @@ def cmd_pool_run(args):
700734
in parallel. Collects results from all workers.
701735
"""
702736
init_logging()
703-
from openadapt_evals.infrastructure.azure_vm import AzureVMManager
704737
from openadapt_evals.infrastructure.pool import PoolManager
705738

706739
num_tasks = getattr(args, "tasks", 10)
707740
agent = getattr(args, "agent", "navi")
708741
model = getattr(args, "model", "gpt-4o-mini")
709742
api_key = getattr(args, "api_key", None)
710743

711-
vm_manager = AzureVMManager(resource_group=RESOURCE_GROUP)
744+
vm_manager = _create_vm_manager(getattr(args, "cloud", None))
712745
manager = PoolManager(vm_manager=vm_manager, log_fn=log)
713746

714747
try:
@@ -731,10 +764,9 @@ def cmd_pool_cleanup(args):
731764
weren't properly deleted.
732765
"""
733766
init_logging()
734-
from openadapt_evals.infrastructure.azure_vm import AzureVMManager
735767
from openadapt_evals.infrastructure.pool import PoolManager
736768

737-
vm_manager = AzureVMManager(resource_group=RESOURCE_GROUP)
769+
vm_manager = _create_vm_manager(getattr(args, "cloud", None))
738770
manager = PoolManager(vm_manager=vm_manager, log_fn=log)
739771

740772
confirm = not getattr(args, "yes", False)
@@ -752,7 +784,6 @@ def cmd_pool_auto(args):
752784
If a pool already exists, skips creation and resumes from wait → run.
753785
"""
754786
init_logging()
755-
from openadapt_evals.infrastructure.azure_vm import AzureVMManager
756787
from openadapt_evals.infrastructure.pool import PoolManager
757788

758789
num_workers = getattr(args, "workers", 1)
@@ -763,7 +794,7 @@ def cmd_pool_auto(args):
763794
model = getattr(args, "model", "gpt-4o-mini")
764795
api_key = getattr(args, "api_key", None)
765796

766-
vm_manager = AzureVMManager(resource_group=RESOURCE_GROUP)
797+
vm_manager = _create_vm_manager(getattr(args, "cloud", None))
767798
manager = PoolManager(vm_manager=vm_manager, log_fn=log)
768799

769800
try:
@@ -820,10 +851,9 @@ def cmd_pool_pause(args):
820851
instead of recreating from scratch (~42 min). Idle cost ~$0.25/day.
821852
"""
822853
init_logging()
823-
from openadapt_evals.infrastructure.azure_vm import AzureVMManager
824854
from openadapt_evals.infrastructure.pool import PoolManager
825855

826-
vm_manager = AzureVMManager(resource_group=RESOURCE_GROUP)
856+
vm_manager = _create_vm_manager(getattr(args, "cloud", None))
827857
manager = PoolManager(vm_manager=vm_manager, log_fn=log)
828858

829859
try:
@@ -842,12 +872,11 @@ def cmd_pool_resume(args):
842872
(~5 min vs ~42 min).
843873
"""
844874
init_logging()
845-
from openadapt_evals.infrastructure.azure_vm import AzureVMManager
846875
from openadapt_evals.infrastructure.pool import PoolManager
847876

848877
timeout_minutes = getattr(args, "timeout", 10)
849878

850-
vm_manager = AzureVMManager(resource_group=RESOURCE_GROUP)
879+
vm_manager = _create_vm_manager(getattr(args, "cloud", None))
851880
manager = PoolManager(vm_manager=vm_manager, log_fn=log)
852881

853882
try:
@@ -884,6 +913,8 @@ def cmd_pool_vnc(args):
884913
worker_name = getattr(args, "worker", None)
885914
all_workers = getattr(args, "all", False)
886915

916+
ssh_user = _get_pool_ssh_username(pool)
917+
887918
if all_workers:
888919
# Set up tunnels for all workers
889920
log("POOL-VNC", f"Setting up VNC tunnels for {len(workers)} workers...")
@@ -912,7 +943,7 @@ def cmd_pool_vnc(args):
912943
"-N",
913944
"-L",
914945
f"{local_port}:localhost:8006",
915-
f"azureuser@{ip}",
946+
f"{ssh_user}@{ip}",
916947
],
917948
stdout=subprocess.DEVNULL,
918949
stderr=subprocess.DEVNULL,
@@ -989,7 +1020,7 @@ def cmd_pool_vnc(args):
9891020
"-N",
9901021
"-L",
9911022
f"{local_port}:localhost:8006",
992-
f"azureuser@{ip}",
1023+
f"{ssh_user}@{ip}",
9931024
],
9941025
stdout=subprocess.DEVNULL,
9951026
stderr=subprocess.DEVNULL,
@@ -1042,6 +1073,7 @@ def cmd_pool_logs(args):
10421073
print("ERROR: Pool has no workers.")
10431074
return 1
10441075

1076+
ssh_user = _get_pool_ssh_username(pool)
10451077
pool_id = pool.get("pool_id", "unknown")
10461078
print(f"[pool-logs] Streaming logs from {len(workers)} workers (pool: {pool_id})")
10471079
print("[pool-logs] Press Ctrl+C to stop\n", flush=True)
@@ -1060,7 +1092,7 @@ def stream_worker_logs(worker_name: str, ip: str):
10601092
"UserKnownHostsFile=/dev/null",
10611093
"-o",
10621094
"LogLevel=ERROR",
1063-
f"azureuser@{ip}",
1095+
f"{ssh_user}@{ip}",
10641096
"docker logs -f winarena",
10651097
]
10661098
try:
@@ -1132,6 +1164,7 @@ def cmd_pool_exec(args):
11321164
log("POOL-EXEC", "ERROR: Pool has no workers.")
11331165
return 1
11341166

1167+
ssh_user = _get_pool_ssh_username(pool)
11351168
cmd = getattr(args, "cmd", None)
11361169
docker = getattr(args, "docker", False)
11371170
worker_filter = getattr(args, "worker", None)
@@ -1164,7 +1197,7 @@ def cmd_pool_exec(args):
11641197

11651198
try:
11661199
result = subprocess.run(
1167-
["ssh", *SSH_OPTS, f"azureuser@{ip}", full_cmd],
1200+
["ssh", *SSH_OPTS, f"{ssh_user}@{ip}", full_cmd],
11681201
capture_output=True,
11691202
text=True,
11701203
timeout=60,
@@ -7623,10 +7656,18 @@ def main():
76237656
p_delete = subparsers.add_parser("delete", help="Delete VM and all resources")
76247657
p_delete.set_defaults(func=cmd_delete)
76257658

7659+
# Shared --cloud argument for pool commands
7660+
_cloud_kwargs = {
7661+
"choices": ["azure", "aws"],
7662+
"default": None,
7663+
"help": "Cloud provider (default: from config, usually azure)",
7664+
}
7665+
76267666
# pool-status
76277667
p_pool_status = subparsers.add_parser(
76287668
"pool-status", help="Show status of all VMs in the current pool"
76297669
)
7670+
p_pool_status.add_argument("--cloud", **_cloud_kwargs)
76307671
p_pool_status.add_argument(
76317672
"--probe",
76327673
action="store_true",
@@ -7636,13 +7677,15 @@ def main():
76367677

76377678
# delete-pool
76387679
p_delete_pool = subparsers.add_parser("delete-pool", help="Delete all VMs in the current pool")
7680+
p_delete_pool.add_argument("--cloud", **_cloud_kwargs)
76397681
p_delete_pool.add_argument("-y", "--yes", action="store_true", help="Skip confirmation")
76407682
p_delete_pool.set_defaults(func=cmd_delete_pool)
76417683

76427684
# pool-create
76437685
p_pool_create = subparsers.add_parser(
76447686
"pool-create", help="Create a pool of VMs for parallel WAA evaluation"
76457687
)
7688+
p_pool_create.add_argument("--cloud", **_cloud_kwargs)
76467689
p_pool_create.add_argument(
76477690
"--workers",
76487691
"-n",
@@ -7671,6 +7714,7 @@ def main():
76717714
p_pool_wait = subparsers.add_parser(
76727715
"pool-wait", help="Wait for all pool workers to have WAA ready"
76737716
)
7717+
p_pool_wait.add_argument("--cloud", **_cloud_kwargs)
76747718
p_pool_wait.add_argument(
76757719
"--timeout", "-t", type=int, default=30, help="Timeout in minutes (default: 30)"
76767720
)
@@ -7685,6 +7729,7 @@ def main():
76857729
p_pool_run = subparsers.add_parser(
76867730
"pool-run", help="Run WAA benchmark tasks distributed across pool workers"
76877731
)
7732+
p_pool_run.add_argument("--cloud", **_cloud_kwargs)
76887733
p_pool_run.add_argument(
76897734
"--tasks",
76907735
"-n",
@@ -7703,6 +7748,7 @@ def main():
77037748
p_pool_cleanup = subparsers.add_parser(
77047749
"pool-cleanup", help="Clean up orphaned pool resources (VMs, NICs, IPs, disks)"
77057750
)
7751+
p_pool_cleanup.add_argument("--cloud", **_cloud_kwargs)
77067752
p_pool_cleanup.add_argument("-y", "--yes", action="store_true", help="Skip confirmation")
77077753
p_pool_cleanup.set_defaults(func=cmd_pool_cleanup)
77087754

@@ -7711,6 +7757,7 @@ def main():
77117757
"pool-auto",
77127758
help="Fully automated: create VMs → wait for WAA → run benchmark",
77137759
)
7760+
p_pool_auto.add_argument("--cloud", **_cloud_kwargs)
77147761
p_pool_auto.add_argument(
77157762
"--workers", "-w", type=int, default=1, help="Number of worker VMs (default: 1)"
77167763
)
@@ -7738,13 +7785,15 @@ def main():
77387785
"pool-pause",
77397786
help="Deallocate pool VMs (stops compute billing, keeps disks ~$0.25/day)",
77407787
)
7788+
p_pool_pause.add_argument("--cloud", **_cloud_kwargs)
77417789
p_pool_pause.set_defaults(func=cmd_pool_pause)
77427790

77437791
# pool-resume
77447792
p_pool_resume = subparsers.add_parser(
77457793
"pool-resume",
77467794
help="Resume a paused pool (start VMs, wait for WAA ~5 min)",
77477795
)
7796+
p_pool_resume.add_argument("--cloud", **_cloud_kwargs)
77487797
p_pool_resume.add_argument(
77497798
"--timeout",
77507799
"-t",

openadapt_evals/config.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,10 @@ class Settings(BaseSettings):
4545
azure_checkpoints_container: str = "checkpoints"
4646
azure_comparisons_container: str = "comparisons"
4747

48+
# Multi-cloud settings
49+
cloud_provider: str = "azure" # "azure" or "aws"
50+
aws_region: str = "us-east-1"
51+
4852
model_config = {
4953
"env_file": ".env",
5054
"env_file_encoding": "utf-8",

openadapt_evals/infrastructure/__init__.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
"""Infrastructure components for VM management and monitoring.
22
33
This module provides:
4+
- VMProvider: Cloud-agnostic VM provider protocol
45
- AzureVMManager: Azure VM lifecycle management (SDK + CLI fallback)
6+
- AWSVMManager: AWS EC2 lifecycle management (boto3)
57
- PoolManager: Multi-VM pool orchestration
68
- VMMonitor: Azure VM status monitoring
79
- AzureOpsTracker: Azure operation logging
@@ -12,10 +14,15 @@
1214
```python
1315
from openadapt_evals.infrastructure import AzureVMManager, PoolManager
1416
15-
# Manage VMs
17+
# Manage VMs (Azure)
1618
vm = AzureVMManager()
1719
ip = vm.get_vm_ip("waa-eval-vm")
1820
21+
# Or use AWS
22+
from openadapt_evals.infrastructure import AWSVMManager
23+
vm = AWSVMManager(region="us-east-1")
24+
pool = PoolManager(vm_manager=vm)
25+
1926
# Create and manage pools
2027
pool = PoolManager()
2128
pool.create(workers=3)
@@ -42,15 +49,23 @@
4249
from openadapt_evals.infrastructure.ssh_tunnel import SSHTunnelManager, get_tunnel_manager
4350
from openadapt_evals.infrastructure.vm_ip import resolve_vm_ip
4451
from openadapt_evals.infrastructure.vm_monitor import VMMonitor, VMConfig
52+
from openadapt_evals.infrastructure.vm_provider import VMProvider
53+
54+
try:
55+
from openadapt_evals.infrastructure.aws_vm import AWSVMManager
56+
except ImportError:
57+
AWSVMManager = None # boto3 not installed; use `pip install openadapt-evals[aws]`
4558

4659
__all__ = [
60+
"AWSVMManager",
4761
"AzureOpsTracker",
4862
"AzureVMManager",
4963
"PoolManager",
5064
"PoolRunResult",
5165
"QEMUResetManager",
5266
"VMMonitor",
5367
"VMConfig",
68+
"VMProvider",
5469
"SSHTunnelManager",
5570
"compare_screenshots",
5671
"get_tunnel_manager",

0 commit comments

Comments
 (0)