|
15 | 15 | import json |
16 | 16 | import logging |
17 | 17 | from unittest import mock |
| 18 | +from typing import Any |
18 | 19 |
|
19 | 20 | from absl.testing import absltest |
20 | 21 | from absl.testing import parameterized |
@@ -53,6 +54,40 @@ def setUp(self): |
53 | 54 | self.mock_original_stop_trace = self.enter_context( |
54 | 55 | mock.patch.object(profiling, "_original_stop_trace", autospec=True) |
55 | 56 | ) |
| 57 | + self.mock_datetime = self.enter_context( |
| 58 | + mock.patch.object(profiling.datetime, "datetime", autospec=True) |
| 59 | + ) |
| 60 | + self.mock_datetime.now.return_value.strftime.return_value = ( |
| 61 | + "2026_06_04_05_29_33" |
| 62 | + ) |
| 63 | + |
| 64 | + def _get_expected_profile_request( |
| 65 | + self, |
| 66 | + trace_location: str, |
| 67 | + max_num_hosts: int = 1, |
| 68 | + session_id: str = "2026_06_04_05_29_33", |
| 69 | + ) -> dict[str, Any]: |
| 70 | + if jax.version.__version_info__ >= (0, 9, 2): |
| 71 | + return { |
| 72 | + "profileRequest": { |
| 73 | + "traceLocation": trace_location, |
| 74 | + "maxNumHosts": max_num_hosts, |
| 75 | + "xprofTraceOptions": { |
| 76 | + "traceDirectory": trace_location, |
| 77 | + "pwTraceOptions": { |
| 78 | + "enablePythonTracer": True, |
| 79 | + }, |
| 80 | + "traceSessionName": session_id, |
| 81 | + }, |
| 82 | + } |
| 83 | + } |
| 84 | + else: |
| 85 | + return { |
| 86 | + "profileRequest": { |
| 87 | + "traceLocation": trace_location, |
| 88 | + "maxNumHosts": max_num_hosts, |
| 89 | + } |
| 90 | + } |
56 | 91 |
|
57 | 92 | @parameterized.parameters(8000, 1234) |
58 | 93 | def test_collect_profile_port(self, port): |
@@ -228,40 +263,67 @@ def test_start_trace_success(self): |
228 | 263 | profiling.start_trace("gs://test_bucket/test_dir") |
229 | 264 |
|
230 | 265 | self.mock_toy_computation.assert_called_once() |
| 266 | + expected_request = self._get_expected_profile_request( |
| 267 | + "gs://test_bucket/test_dir", max_num_hosts=1 |
| 268 | + ) |
231 | 269 | self.mock_plugin_executable_cls.assert_called_once_with( |
232 | | - json.dumps({ |
233 | | - "profileRequest": { |
234 | | - "traceLocation": "gs://test_bucket/test_dir", |
235 | | - "maxNumHosts": 1, |
236 | | - } |
237 | | - }) |
| 270 | + json.dumps(expected_request) |
238 | 271 | ) |
239 | 272 | self.mock_plugin_executable_cls.return_value.call.assert_called_once() |
240 | | - self.mock_original_start_trace.assert_called_once_with( |
241 | | - log_dir="gs://test_bucket/test_dir", |
242 | | - create_perfetto_link=False, |
243 | | - create_perfetto_trace=False, |
244 | | - ) |
| 273 | + self.mock_original_start_trace.assert_called_once() |
| 274 | + call_args = self.mock_original_start_trace.call_args[1] |
| 275 | + self.assertEqual(call_args["log_dir"], "gs://test_bucket/test_dir") |
| 276 | + self.assertFalse(call_args["create_perfetto_link"]) |
| 277 | + self.assertFalse(call_args["create_perfetto_trace"]) |
| 278 | + if jax.version.__version_info__ >= (0, 9, 2): |
| 279 | + self.assertEqual( |
| 280 | + call_args["profiler_options"].session_id, "2026_06_04_05_29_33" |
| 281 | + ) |
245 | 282 | self.assertIsNotNone(profiling._profile_state.executable) |
246 | 283 |
|
247 | 284 | def test_start_trace_with_max_num_hosts(self): |
248 | 285 | profiling.start_trace("gs://test_bucket/test_dir", max_num_hosts=10) |
249 | 286 |
|
250 | 287 | self.mock_toy_computation.assert_called_once() |
| 288 | + expected_request = self._get_expected_profile_request( |
| 289 | + "gs://test_bucket/test_dir", max_num_hosts=10 |
| 290 | + ) |
251 | 291 | self.mock_plugin_executable_cls.assert_called_once_with( |
252 | | - json.dumps({ |
253 | | - "profileRequest": { |
254 | | - "traceLocation": "gs://test_bucket/test_dir", |
255 | | - "maxNumHosts": 10, |
256 | | - } |
257 | | - }) |
| 292 | + json.dumps(expected_request) |
258 | 293 | ) |
259 | 294 | self.mock_plugin_executable_cls.return_value.call.assert_called_once() |
260 | | - self.mock_original_start_trace.assert_called_once_with( |
261 | | - log_dir="gs://test_bucket/test_dir", |
262 | | - create_perfetto_link=False, |
263 | | - create_perfetto_trace=False, |
| 295 | + self.mock_original_start_trace.assert_called_once() |
| 296 | + call_args = self.mock_original_start_trace.call_args[1] |
| 297 | + self.assertEqual(call_args["log_dir"], "gs://test_bucket/test_dir") |
| 298 | + self.assertFalse(call_args["create_perfetto_link"]) |
| 299 | + self.assertFalse(call_args["create_perfetto_trace"]) |
| 300 | + if jax.version.__version_info__ >= (0, 9, 2): |
| 301 | + self.assertEqual( |
| 302 | + call_args["profiler_options"].session_id, "2026_06_04_05_29_33" |
| 303 | + ) |
| 304 | + |
| 305 | + @absltest.skipIf( |
| 306 | + jax.version.__version_info__ < (0, 9, 2), |
| 307 | + "ProfileOptions requires JAX 0.9.2 or newer", |
| 308 | + ) |
| 309 | + def test_start_trace_with_session_id_in_options(self): |
| 310 | + options = jax.profiler.ProfileOptions() |
| 311 | + options.session_id = "options_session" |
| 312 | + profiling.start_trace("gs://test_bucket/test_dir", profiler_options=options) |
| 313 | + |
| 314 | + expected_request = self._get_expected_profile_request( |
| 315 | + "gs://test_bucket/test_dir", max_num_hosts=1, session_id="options_session" |
264 | 316 | ) |
| 317 | + self.mock_plugin_executable_cls.assert_called_once_with( |
| 318 | + json.dumps(expected_request) |
| 319 | + ) |
| 320 | + self.assertEqual(options.session_id, "options_session") |
| 321 | + self.mock_original_start_trace.assert_called_once() |
| 322 | + call_args = self.mock_original_start_trace.call_args[1] |
| 323 | + self.assertEqual(call_args["log_dir"], "gs://test_bucket/test_dir") |
| 324 | + self.assertFalse(call_args["create_perfetto_link"]) |
| 325 | + self.assertFalse(call_args["create_perfetto_trace"]) |
| 326 | + self.assertEqual(call_args["profiler_options"].session_id, "options_session") |
265 | 327 |
|
266 | 328 | def test_start_trace_no_toy_computation_second_time(self): |
267 | 329 | profiling.start_trace("gs://test_bucket/test_dir") |
|
0 commit comments