@@ -76,6 +76,44 @@ def _get_resource_group() -> str:
7676
7777
7878RESOURCE_GROUP = _get_resource_group ()
79+
80+
81+ def _get_default_cloud () -> str :
82+ """Get default cloud provider from config."""
83+ try :
84+ from openadapt_evals .config import settings
85+
86+ return settings .cloud_provider
87+ except Exception :
88+ return "azure"
89+
90+
91+ def _get_pool_ssh_username (pool : dict ) -> str :
92+ """Get SSH username from pool registry, defaulting to azureuser for backward compat."""
93+ return pool .get ("ssh_username" , "azureuser" )
94+
95+
96+ def _create_vm_manager (cloud : str | None = None , resource_group : str | None = None ):
97+ """Factory to create the appropriate VM manager based on cloud provider.
98+
99+ Args:
100+ cloud: Cloud provider ("azure" or "aws"). If None, uses config default.
101+ resource_group: Azure resource group (ignored for AWS).
102+
103+ Returns:
104+ VMProvider instance (AzureVMManager or AWSVMManager).
105+ """
106+ cloud = cloud or _get_default_cloud ()
107+ if cloud == "aws" :
108+ from openadapt_evals .infrastructure .aws_vm import AWSVMManager
109+
110+ return AWSVMManager ()
111+ else :
112+ from openadapt_evals .infrastructure .azure_vm import AzureVMManager
113+
114+ return AzureVMManager (resource_group = resource_group or RESOURCE_GROUP )
115+
116+
79117# Custom WAA image built from waa_deploy/Dockerfile
80118# Uses dockurr/windows:latest as base (with proper ISO download) + WAA components
81119DOCKER_IMAGE = "waa-auto:latest"
@@ -518,11 +556,10 @@ def cmd_pool_status(args):
518556 """Show status of all VMs in the current pool."""
519557 init_logging ()
520558
521- from openadapt_evals .infrastructure .azure_vm import AzureVMManager
522559 from openadapt_evals .infrastructure .pool import PoolManager
523560 from openadapt_evals .infrastructure .vm_monitor import VMMonitor , VMConfig
524561
525- vm_manager = AzureVMManager ( resource_group = RESOURCE_GROUP )
562+ vm_manager = _create_vm_manager ( getattr ( args , "cloud" , None ) )
526563 manager = PoolManager (vm_manager = vm_manager , log_fn = log )
527564 pool = manager .status ()
528565
@@ -592,10 +629,9 @@ def cmd_delete_pool(args):
592629 init_logging ()
593630 from concurrent .futures import ThreadPoolExecutor , as_completed
594631
595- from openadapt_evals .infrastructure .azure_vm import AzureVMManager
596632 from openadapt_evals .infrastructure .pool import PoolManager
597633
598- vm_manager = AzureVMManager ( resource_group = RESOURCE_GROUP )
634+ vm_manager = _create_vm_manager ( getattr ( args , "cloud" , None ) )
599635 manager = PoolManager (vm_manager = vm_manager , log_fn = log )
600636 pool = manager .status ()
601637
@@ -642,15 +678,14 @@ def cmd_pool_create(args):
642678 Uses ThreadPoolExecutor for concurrent VM creation.
643679 """
644680 init_logging ()
645- from openadapt_evals .infrastructure .azure_vm import AzureVMManager
646681 from openadapt_evals .infrastructure .pool import PoolManager
647682
648683 num_workers = getattr (args , "workers" , 3 )
649684 auto_shutdown_hours = getattr (args , "auto_shutdown_hours" , 4 )
650685 use_acr = getattr (args , "use_acr" , False )
651686 image_id = getattr (args , "image" , None )
652687
653- vm_manager = AzureVMManager ( resource_group = RESOURCE_GROUP )
688+ vm_manager = _create_vm_manager ( getattr ( args , "cloud" , None ) )
654689 manager = PoolManager (vm_manager = vm_manager , log_fn = log )
655690
656691 try :
@@ -673,13 +708,12 @@ def cmd_pool_wait(args):
673708 and the WAA server to respond.
674709 """
675710 init_logging ()
676- from openadapt_evals .infrastructure .azure_vm import AzureVMManager
677711 from openadapt_evals .infrastructure .pool import PoolManager
678712
679713 timeout_minutes = getattr (args , "timeout" , 30 )
680714 no_start = getattr (args , "no_start" , False )
681715
682- vm_manager = AzureVMManager ( resource_group = RESOURCE_GROUP )
716+ vm_manager = _create_vm_manager ( getattr ( args , "cloud" , None ) )
683717 manager = PoolManager (vm_manager = vm_manager , log_fn = log )
684718
685719 try :
@@ -700,15 +734,14 @@ def cmd_pool_run(args):
700734 in parallel. Collects results from all workers.
701735 """
702736 init_logging ()
703- from openadapt_evals .infrastructure .azure_vm import AzureVMManager
704737 from openadapt_evals .infrastructure .pool import PoolManager
705738
706739 num_tasks = getattr (args , "tasks" , 10 )
707740 agent = getattr (args , "agent" , "navi" )
708741 model = getattr (args , "model" , "gpt-4o-mini" )
709742 api_key = getattr (args , "api_key" , None )
710743
711- vm_manager = AzureVMManager ( resource_group = RESOURCE_GROUP )
744+ vm_manager = _create_vm_manager ( getattr ( args , "cloud" , None ) )
712745 manager = PoolManager (vm_manager = vm_manager , log_fn = log )
713746
714747 try :
@@ -731,10 +764,9 @@ def cmd_pool_cleanup(args):
731764 weren't properly deleted.
732765 """
733766 init_logging ()
734- from openadapt_evals .infrastructure .azure_vm import AzureVMManager
735767 from openadapt_evals .infrastructure .pool import PoolManager
736768
737- vm_manager = AzureVMManager ( resource_group = RESOURCE_GROUP )
769+ vm_manager = _create_vm_manager ( getattr ( args , "cloud" , None ) )
738770 manager = PoolManager (vm_manager = vm_manager , log_fn = log )
739771
740772 confirm = not getattr (args , "yes" , False )
@@ -752,7 +784,6 @@ def cmd_pool_auto(args):
752784 If a pool already exists, skips creation and resumes from wait → run.
753785 """
754786 init_logging ()
755- from openadapt_evals .infrastructure .azure_vm import AzureVMManager
756787 from openadapt_evals .infrastructure .pool import PoolManager
757788
758789 num_workers = getattr (args , "workers" , 1 )
@@ -763,7 +794,7 @@ def cmd_pool_auto(args):
763794 model = getattr (args , "model" , "gpt-4o-mini" )
764795 api_key = getattr (args , "api_key" , None )
765796
766- vm_manager = AzureVMManager ( resource_group = RESOURCE_GROUP )
797+ vm_manager = _create_vm_manager ( getattr ( args , "cloud" , None ) )
767798 manager = PoolManager (vm_manager = vm_manager , log_fn = log )
768799
769800 try :
@@ -820,10 +851,9 @@ def cmd_pool_pause(args):
820851 instead of recreating from scratch (~42 min). Idle cost ~$0.25/day.
821852 """
822853 init_logging ()
823- from openadapt_evals .infrastructure .azure_vm import AzureVMManager
824854 from openadapt_evals .infrastructure .pool import PoolManager
825855
826- vm_manager = AzureVMManager ( resource_group = RESOURCE_GROUP )
856+ vm_manager = _create_vm_manager ( getattr ( args , "cloud" , None ) )
827857 manager = PoolManager (vm_manager = vm_manager , log_fn = log )
828858
829859 try :
@@ -842,12 +872,11 @@ def cmd_pool_resume(args):
842872 (~5 min vs ~42 min).
843873 """
844874 init_logging ()
845- from openadapt_evals .infrastructure .azure_vm import AzureVMManager
846875 from openadapt_evals .infrastructure .pool import PoolManager
847876
848877 timeout_minutes = getattr (args , "timeout" , 10 )
849878
850- vm_manager = AzureVMManager ( resource_group = RESOURCE_GROUP )
879+ vm_manager = _create_vm_manager ( getattr ( args , "cloud" , None ) )
851880 manager = PoolManager (vm_manager = vm_manager , log_fn = log )
852881
853882 try :
@@ -884,6 +913,8 @@ def cmd_pool_vnc(args):
884913 worker_name = getattr (args , "worker" , None )
885914 all_workers = getattr (args , "all" , False )
886915
916+ ssh_user = _get_pool_ssh_username (pool )
917+
887918 if all_workers :
888919 # Set up tunnels for all workers
889920 log ("POOL-VNC" , f"Setting up VNC tunnels for { len (workers )} workers..." )
@@ -912,7 +943,7 @@ def cmd_pool_vnc(args):
912943 "-N" ,
913944 "-L" ,
914945 f"{ local_port } :localhost:8006" ,
915- f"azureuser @{ ip } " ,
946+ f"{ ssh_user } @{ ip } " ,
916947 ],
917948 stdout = subprocess .DEVNULL ,
918949 stderr = subprocess .DEVNULL ,
@@ -989,7 +1020,7 @@ def cmd_pool_vnc(args):
9891020 "-N" ,
9901021 "-L" ,
9911022 f"{ local_port } :localhost:8006" ,
992- f"azureuser @{ ip } " ,
1023+ f"{ ssh_user } @{ ip } " ,
9931024 ],
9941025 stdout = subprocess .DEVNULL ,
9951026 stderr = subprocess .DEVNULL ,
@@ -1042,6 +1073,7 @@ def cmd_pool_logs(args):
10421073 print ("ERROR: Pool has no workers." )
10431074 return 1
10441075
1076+ ssh_user = _get_pool_ssh_username (pool )
10451077 pool_id = pool .get ("pool_id" , "unknown" )
10461078 print (f"[pool-logs] Streaming logs from { len (workers )} workers (pool: { pool_id } )" )
10471079 print ("[pool-logs] Press Ctrl+C to stop\n " , flush = True )
@@ -1060,7 +1092,7 @@ def stream_worker_logs(worker_name: str, ip: str):
10601092 "UserKnownHostsFile=/dev/null" ,
10611093 "-o" ,
10621094 "LogLevel=ERROR" ,
1063- f"azureuser @{ ip } " ,
1095+ f"{ ssh_user } @{ ip } " ,
10641096 "docker logs -f winarena" ,
10651097 ]
10661098 try :
@@ -1132,6 +1164,7 @@ def cmd_pool_exec(args):
11321164 log ("POOL-EXEC" , "ERROR: Pool has no workers." )
11331165 return 1
11341166
1167+ ssh_user = _get_pool_ssh_username (pool )
11351168 cmd = getattr (args , "cmd" , None )
11361169 docker = getattr (args , "docker" , False )
11371170 worker_filter = getattr (args , "worker" , None )
@@ -1164,7 +1197,7 @@ def cmd_pool_exec(args):
11641197
11651198 try :
11661199 result = subprocess .run (
1167- ["ssh" , * SSH_OPTS , f"azureuser @{ ip } " , full_cmd ],
1200+ ["ssh" , * SSH_OPTS , f"{ ssh_user } @{ ip } " , full_cmd ],
11681201 capture_output = True ,
11691202 text = True ,
11701203 timeout = 60 ,
@@ -7623,10 +7656,18 @@ def main():
76237656 p_delete = subparsers .add_parser ("delete" , help = "Delete VM and all resources" )
76247657 p_delete .set_defaults (func = cmd_delete )
76257658
7659+ # Shared --cloud argument for pool commands
7660+ _cloud_kwargs = {
7661+ "choices" : ["azure" , "aws" ],
7662+ "default" : None ,
7663+ "help" : "Cloud provider (default: from config, usually azure)" ,
7664+ }
7665+
76267666 # pool-status
76277667 p_pool_status = subparsers .add_parser (
76287668 "pool-status" , help = "Show status of all VMs in the current pool"
76297669 )
7670+ p_pool_status .add_argument ("--cloud" , ** _cloud_kwargs )
76307671 p_pool_status .add_argument (
76317672 "--probe" ,
76327673 action = "store_true" ,
@@ -7636,13 +7677,15 @@ def main():
76367677
76377678 # delete-pool
76387679 p_delete_pool = subparsers .add_parser ("delete-pool" , help = "Delete all VMs in the current pool" )
7680+ p_delete_pool .add_argument ("--cloud" , ** _cloud_kwargs )
76397681 p_delete_pool .add_argument ("-y" , "--yes" , action = "store_true" , help = "Skip confirmation" )
76407682 p_delete_pool .set_defaults (func = cmd_delete_pool )
76417683
76427684 # pool-create
76437685 p_pool_create = subparsers .add_parser (
76447686 "pool-create" , help = "Create a pool of VMs for parallel WAA evaluation"
76457687 )
7688+ p_pool_create .add_argument ("--cloud" , ** _cloud_kwargs )
76467689 p_pool_create .add_argument (
76477690 "--workers" ,
76487691 "-n" ,
@@ -7671,6 +7714,7 @@ def main():
76717714 p_pool_wait = subparsers .add_parser (
76727715 "pool-wait" , help = "Wait for all pool workers to have WAA ready"
76737716 )
7717+ p_pool_wait .add_argument ("--cloud" , ** _cloud_kwargs )
76747718 p_pool_wait .add_argument (
76757719 "--timeout" , "-t" , type = int , default = 30 , help = "Timeout in minutes (default: 30)"
76767720 )
@@ -7685,6 +7729,7 @@ def main():
76857729 p_pool_run = subparsers .add_parser (
76867730 "pool-run" , help = "Run WAA benchmark tasks distributed across pool workers"
76877731 )
7732+ p_pool_run .add_argument ("--cloud" , ** _cloud_kwargs )
76887733 p_pool_run .add_argument (
76897734 "--tasks" ,
76907735 "-n" ,
@@ -7703,6 +7748,7 @@ def main():
77037748 p_pool_cleanup = subparsers .add_parser (
77047749 "pool-cleanup" , help = "Clean up orphaned pool resources (VMs, NICs, IPs, disks)"
77057750 )
7751+ p_pool_cleanup .add_argument ("--cloud" , ** _cloud_kwargs )
77067752 p_pool_cleanup .add_argument ("-y" , "--yes" , action = "store_true" , help = "Skip confirmation" )
77077753 p_pool_cleanup .set_defaults (func = cmd_pool_cleanup )
77087754
@@ -7711,6 +7757,7 @@ def main():
77117757 "pool-auto" ,
77127758 help = "Fully automated: create VMs → wait for WAA → run benchmark" ,
77137759 )
7760+ p_pool_auto .add_argument ("--cloud" , ** _cloud_kwargs )
77147761 p_pool_auto .add_argument (
77157762 "--workers" , "-w" , type = int , default = 1 , help = "Number of worker VMs (default: 1)"
77167763 )
@@ -7738,13 +7785,15 @@ def main():
77387785 "pool-pause" ,
77397786 help = "Deallocate pool VMs (stops compute billing, keeps disks ~$0.25/day)" ,
77407787 )
7788+ p_pool_pause .add_argument ("--cloud" , ** _cloud_kwargs )
77417789 p_pool_pause .set_defaults (func = cmd_pool_pause )
77427790
77437791 # pool-resume
77447792 p_pool_resume = subparsers .add_parser (
77457793 "pool-resume" ,
77467794 help = "Resume a paused pool (start VMs, wait for WAA ~5 min)" ,
77477795 )
7796+ p_pool_resume .add_argument ("--cloud" , ** _cloud_kwargs )
77487797 p_pool_resume .add_argument (
77497798 "--timeout" ,
77507799 "-t" ,
0 commit comments