2828
2929@ray .remote (num_gpus = 1 )
3030class WorkerActor :
31+ """A Ray actor that owns a single GPU and benchmarks kernel configurations on it."""
3132 def __init__ (
3233 self , kernel_source , kernel_options , device_options , tuning_options , iterations , observers
3334 ):
@@ -67,7 +68,7 @@ def get_environment(self):
6768 return env
6869
6970 def run (self , params ):
70- # TODO: logging.debug("sequential runner started for " + self.kernel_options.kernel_name)
71+ # logging.debug("sequential runner started for " + self.kernel_options.kernel_name)
7172 result = None
7273
7374 # attempt to warmup the GPU by running the first config in the parameter space and ignoring the result
@@ -86,17 +87,20 @@ def run(self, params):
8687 params ["ray_actor_id" ] = ray .get_runtime_context ().get_actor_id ()
8788 params ["host_name" ] = socket .gethostname ()
8889
89- # all visited configurations are added to results to provide a trace for optimization strategies
9090 return params
9191
9292
9393class Worker :
94+ """Local handle for a ``WorkerActor`` running in a remote Ray worker process."""
95+
def __init__(self, index, actor):
    """Set up the local bookkeeping for one remote actor.

    Stores ``index`` and ``actor`` on the instance, starts with an empty
    ``running_jobs`` list, a ``maximum_running_jobs`` limit of 2 and
    ``is_running`` set to True, and finally fetches the actor's environment
    into ``self.env``.
    """
    # Identity of this handle and the actor it wraps.
    self.actor = actor
    self.index = index

    # Job bookkeeping: nothing in flight yet.
    self.is_running = True
    self.maximum_running_jobs = 2
    self.running_jobs = []

    # `ray.get` blocks here until the actor has answered `get_environment`.
    environment_ref = actor.get_environment.remote()
    self.env = ray.get(environment_ref)
101105
102106 def __repr__ (self ):
@@ -108,9 +112,14 @@ def __repr__(self):
108112 return f"{ self .index } "
109113
110114 def shutdown (self ):
115+ """Request the remote actor to exit and mark this handle as stopped."""
111116 if not self .is_running :
112117 return
113118
119+ # Wait until running jobs complete
120+ if self .running_jobs :
121+ ray .wait (self .running_jobs )
122+
114123 self .is_running = False
115124
116125 try :
@@ -119,11 +128,14 @@ def shutdown(self):
119128 logger .exception ("failed to request actor shutdown: worker %s" , self )
120129
def submit(self, config):
    """Send ``config`` to the actor's ``run`` method and track the job.

    The handle returned by ``self.actor.run.remote(config)`` is appended to
    ``self.running_jobs`` and then returned to the caller.
    """
    launch = self.actor.run.remote
    pending = launch(config)
    self.running_jobs.append(pending)
    return pending
125135
126136 def is_available (self ):
137+ """Return True if this worker can accept another job right now."""
138+
127139 if not self .is_running :
128140 return False
129141
@@ -139,26 +151,26 @@ def launch_workers(n, *args):
139151 workers = []
140152
141153 try :
142- # Start all actors in parallel
154+ # Start all actors in parallel since `WorkerActor.remote` does not block
143155 for _ in range (n ):
144156 actors .append (WorkerActor .remote (* args ))
145157
146- # Create `Worker` objects. This blocks until each worker is ready
158+ # Create local `Worker` objects. This blocks until each worker is ready
147159 for index , actor in enumerate (actors ):
148160 worker = Worker (index , actor )
149161 workers .append (worker )
150162 logging .info ("connected: worker %s" , worker )
151-
152- return workers
153- except :
154- # Attempt to shut down actors
163+ except Exception :
164+ # Attempt to shut down the running actors
155165 for actor in actors :
156166 try :
157167 actor .shutdown .remote ()
158168 except :
159169 logger .exception ("failed to request actor shutdown: %s" , actor )
160170 raise
161171
172+ return workers
173+
162174
163175class ParallelRunner (Runner ):
164176 def __init__ (
@@ -228,7 +240,7 @@ def shutdown(self):
228240 for worker in self .workers :
229241 try :
230242 worker .shutdown ()
231- except Exception as err :
243+ except Exception :
232244 logger .exception (f"error while shutting down worker { worker } " )
233245
234246 def available_parallelism (self ):
@@ -350,8 +362,7 @@ def run(self, parameter_space, tuning_options) -> List[Optional[dict]]:
350362 result = process_metrics (result , metrics )
351363 else :
352364 logging .error (
353- "kernel configuration {key} was skipped silently due to compile or runtime failure" ,
354- key ,
365+ f"kernel configuration { key } was skipped silently due to compile or runtime failure"
355366 )
356367
357368 # print configuration to the console
0 commit comments