3535
3636
3737class _ProfileState :
38+ """Holds the state of an ongoing profiling session.
39+
40+ Attributes:
41+ executable: The `plugin_executable.PluginExecutable` instance used for the
42+ profiling session.
43+ profile_request: The mapping containing the profile request options.
44+ lock: A thread lock to protect access to the state.
45+ """
3846 executable : plugin_executable .PluginExecutable | None = None
3947 profile_request : Mapping [str , Any ] | None = None
4048 lock : threading .Lock
@@ -48,6 +56,37 @@ def reset(self) -> None:
4856 self .executable = None
4957 self .profile_request = None
5058
59+ def call_profile_executable (self ) -> None :
60+ """Calls the profiling executable and waits for the result."""
61+ if self .executable is None :
62+ raise RuntimeError (
63+ "_call_profile_executable called with no active executable."
64+ )
65+ # If the profile request contains xprofTraceOptions, then we need to pass
66+ # out_avals and out_shardings to the executable call because the
67+ # executable will return a future that needs to be resolved. This is true
68+ # for both starting and stopping a trace.
69+ if (
70+ self .profile_request is not None
71+ and "xprofTraceOptions" in self .profile_request
72+ ):
73+ out_avals = [jax .core .ShapedArray ((1 ,), jnp .object_ )]
74+ out_shardings = [
75+ getattr (
76+ jax .sharding ,
77+ "make_single_device_sharding" ,
78+ jax .sharding .SingleDeviceSharding ,
79+ )(jax .devices ()[0 ])
80+ ]
81+ else :
82+ out_avals = ()
83+ out_shardings = ()
84+
85+ _ , result_future = self .executable .call (
86+ out_avals = out_avals , out_shardings = out_shardings
87+ )
88+ result_future .result ()
89+
5190
5291_first_profile_start = True
5392_profile_state = _ProfileState ()
@@ -74,6 +113,7 @@ def _is_default_profile_options(
74113 == default_options .python_tracer_level
75114 and profiler_options .duration_ms == default_options .duration_ms
76115 and not getattr (profiler_options , "advanced_configuration" , None )
116+ and not getattr (profiler_options , "session_id" , None )
77117 )
78118
79119
@@ -128,6 +168,9 @@ def _create_profile_request(
128168 if pw_trace_opts :
129169 xprof_options ["pwTraceOptions" ] = pw_trace_opts
130170
171+ if getattr (profiler_options , "session_id" , None ):
172+ xprof_options ["traceSessionName" ] = profiler_options .session_id
173+
131174 profile_request ["xprofTraceOptions" ] = xprof_options
132175
133176 if profiler_options .duration_ms > 0 :
@@ -153,19 +196,30 @@ def _start_pathways_trace_from_profile_request(
153196 toy_computation ()
154197
155198 if _profile_state .executable is not None :
156- raise ValueError (
199+ raise RuntimeError (
157200 "start_trace called while a trace is already being taken!"
158201 )
159- _profile_state .executable = plugin_executable .PluginExecutable (
160- json .dumps ({"profileRequest" : profile_request })
161- )
162- _profile_state .profile_request = profile_request
163202 try :
164- _ , result_future = _profile_state .executable .call ()
165- result_future .result ()
166- except Exception :
167- _logger .exception ("Failed to start trace" )
203+ _profile_state .executable = plugin_executable .PluginExecutable (
204+ json .dumps ({"profileRequest" : profile_request })
205+ )
206+ _profile_state .profile_request = profile_request
207+ _profile_state .call_profile_executable ()
208+ except Exception as e :
168209 _profile_state .reset ()
210+ if (
211+ "xprofTraceOptions" in profile_request
212+ and "traceSessionName" in profile_request ["xprofTraceOptions" ]
213+ ):
214+ if "Bad PluginProgram" in str (e ):
215+ raise RuntimeError (
216+ "Failed to start Pathways trace. The Pathways backend rejected "
217+ "the request, likely because the running Pathways server images "
218+ "do not support the trace session ID option. Please ensure you "
219+ "are running the latest versions of both Pathways server images "
220+ "and the pathwaysutils library."
221+ ) from e
222+ _logger .exception ("Failed to start trace" )
169223 raise
170224
171225
@@ -243,28 +297,9 @@ def stop_trace() -> None:
243297 try :
244298 with _profile_state .lock :
245299 if _profile_state .executable is None :
246- raise ValueError ("stop_trace called before a trace is being taken!" )
300+ raise RuntimeError ("stop_trace called before a trace is being taken!" )
247301 try :
248- if (
249- _profile_state .profile_request is not None
250- and "xprofTraceOptions" in _profile_state .profile_request
251- ):
252- out_avals = [jax .core .ShapedArray ((1 ,), jnp .object_ )]
253- out_shardings = [
254- getattr (
255- jax .sharding ,
256- "make_single_device_sharding" ,
257- lambda x : jax .sharding .SingleDeviceSharding (x ),
258- )(jax .devices ()[0 ])
259- ]
260- else :
261- out_avals = ()
262- out_shardings = ()
263-
264- _ , result_future = _profile_state .executable .call (
265- out_avals = out_avals , out_shardings = out_shardings
266- )
267- result_future .result ()
302+ _profile_state .call_profile_executable ()
268303 finally :
269304 _profile_state .reset ()
270305 finally :
@@ -306,7 +341,7 @@ async def profiling(pc: ProfilingConfig) -> Mapping[str, str]:
306341
307342 global _profiler_thread
308343 if _profiler_thread is not None :
309- raise ValueError ("Only one profiler server can be active at a time." )
344+ raise RuntimeError ("Only one profiler server can be active at a time." )
310345
311346 _profiler_thread = threading .Thread (target = server_loop , args = (port ,))
312347 _profiler_thread .start ()
@@ -318,7 +353,7 @@ def stop_server() -> None:
318353 Pathways profiling servers are not stoppable at this time.
319354 """
320355 if _profiler_thread is None :
321- raise ValueError ("No active profiler server." )
356+ raise RuntimeError ("No active profiler server." )
322357
323358
324359def collect_profile (
0 commit comments