Skip to content

Commit bce3eef

Browse files
lukebaumanncopybara-github
authored andcommitted
Patch internal JAX profiler functions (enabling jax.profiler.trace) and add a test for jax.profiler.trace.
The `jax.profiler.trace` context manager uses internal `jax._src.profiler` functions. This change ensures that these internal functions are also patched by `pathwaysutils.profiling.monkey_patch_jax` to correctly intercept profiling calls. A new test is added to verify that `with jax.profiler.trace(...)` now triggers the patched Pathways profiling functions. PiperOrigin-RevId: 845135750
1 parent b24bfec commit bce3eef

File tree

2 files changed

+56
-10
lines changed

2 files changed

+56
-10
lines changed

pathwaysutils/profiling.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def start_trace(
104104
*,
105105
create_perfetto_link: bool = False,
106106
create_perfetto_trace: bool = False,
107+
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
107108
) -> None:
108109
"""Starts a profiler trace.
109110
@@ -131,6 +132,8 @@ def start_trace(
131132
want to generate a Perfetto-compatible trace without blocking the process.
132133
This feature is experimental for Pathways on Cloud and may not be fully
133134
supported.
135+
profiler_options: Profiler options to configure the profiler for collection.
136+
Options are not currently supported and ignored.
134137
"""
135138
if not str(log_dir).startswith("gs://"):
136139
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
@@ -270,19 +273,27 @@ def monkey_patch_jax():
270273

271274
def start_trace_patch(
272275
log_dir,
273-
create_perfetto_link: bool = False, # pylint: disable=unused-argument
274-
create_perfetto_trace: bool = False, # pylint: disable=unused-argument
276+
create_perfetto_link: bool = False,
277+
create_perfetto_trace: bool = False,
278+
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
275279
) -> None:
276280
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
277-
return start_trace(log_dir)
281+
return start_trace(
282+
log_dir,
283+
create_perfetto_link=create_perfetto_link,
284+
create_perfetto_trace=create_perfetto_trace,
285+
profiler_options=profiler_options,
286+
)
278287

279288
jax.profiler.start_trace = start_trace_patch
289+
jax._src.profiler.start_trace = start_trace_patch # pylint: disable=protected-access
280290

281291
def stop_trace_patch() -> None:
282292
_logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
283293
return stop_trace()
284294

285295
jax.profiler.stop_trace = stop_trace_patch
296+
jax._src.profiler.stop_trace = stop_trace_patch # pylint: disable=protected-access
286297

287298
def start_server_patch(port: int):
288299
_logger.debug(

pathwaysutils/test/profiling_test.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,18 @@ def test_monkey_patch_jax(self):
310310
original_jax_stop_trace = jax.profiler.stop_trace
311311
original_jax_start_server = jax.profiler.start_server
312312
original_jax_stop_server = jax.profiler.stop_server
313+
self.addCleanup(
314+
setattr, jax.profiler, "start_trace", original_jax_start_trace
315+
)
316+
self.addCleanup(
317+
setattr, jax.profiler, "stop_trace", original_jax_stop_trace
318+
)
319+
self.addCleanup(
320+
setattr, jax.profiler, "start_server", original_jax_start_server
321+
)
322+
self.addCleanup(
323+
setattr, jax.profiler, "stop_server", original_jax_stop_server
324+
)
313325

314326
profiling.monkey_patch_jax()
315327

@@ -322,7 +334,12 @@ def test_monkey_patch_jax(self):
322334
profiling, "start_trace", autospec=True
323335
) as mock_pw_start_trace:
324336
jax.profiler.start_trace("gs://bucket/dir")
325-
mock_pw_start_trace.assert_called_once_with("gs://bucket/dir")
337+
mock_pw_start_trace.assert_called_once_with(
338+
"gs://bucket/dir",
339+
create_perfetto_link=False,
340+
create_perfetto_trace=False,
341+
profiler_options=None,
342+
)
326343

327344
with mock.patch.object(
328345
profiling, "stop_trace", autospec=True
@@ -342,12 +359,6 @@ def test_monkey_patch_jax(self):
342359
jax.profiler.stop_server()
343360
mock_pw_stop_server.assert_called_once()
344361

345-
# Restore original jax functions
346-
jax.profiler.start_trace = original_jax_start_trace
347-
jax.profiler.stop_trace = original_jax_stop_trace
348-
jax.profiler.start_server = original_jax_start_server
349-
jax.profiler.stop_server = original_jax_stop_server
350-
351362
def test_create_profile_request_no_options(self):
352363
request = profiling._create_profile_request("gs://bucket/dir")
353364
self.assertEqual(request, {"traceLocation": "gs://bucket/dir"})
@@ -384,6 +395,7 @@ def test_create_profile_request_no_options(self):
384395
},
385396
},),
386397
)
398+
387399
def test_start_pathways_trace_from_profile_request(self, profile_request):
388400
profiling._start_pathways_trace_from_profile_request(profile_request)
389401

@@ -407,6 +419,29 @@ def test_original_stop_trace_called_on_stop_failure(self):
407419
profiling.stop_trace()
408420
self.mock_original_stop_trace.assert_called_once()
409421

422+
def test_jax_profiler_trace_calls_patched_functions(self):
423+
original_jax_start_trace = jax.profiler.start_trace
424+
original_jax_stop_trace = jax.profiler.stop_trace
425+
self.addCleanup(
426+
setattr, jax.profiler, "start_trace", original_jax_start_trace
427+
)
428+
self.addCleanup(
429+
setattr, jax.profiler, "stop_trace", original_jax_stop_trace
430+
)
431+
mock_pw_start_trace = self.enter_context(
432+
mock.patch.object(profiling, "start_trace", autospec=True)
433+
)
434+
mock_pw_stop_trace = self.enter_context(
435+
mock.patch.object(profiling, "stop_trace", autospec=True)
436+
)
437+
profiling.monkey_patch_jax()
438+
439+
with jax.profiler.trace("gs://bucket/dir"):
440+
pass
441+
442+
mock_pw_start_trace.assert_called_once()
443+
mock_pw_stop_trace.assert_called_once()
444+
410445

411446
if __name__ == "__main__":
412447
absltest.main()

0 commit comments

Comments
 (0)