Skip to content

Commit a8983f8

Browse files
lukebaumannGoogle-ML-Automation
authored andcommitted
Expose profiler advanced configuration as a Python dict.
In profiler.cc, the advanced_configuration property of tensorflow::ProfileOptions is now exposed as a Python dictionary. The getter converts the proto map to a nb::dict, handling different value types (bool, int64, string). Example error: ``` ProfileOptions().advanced_configuration ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: Unable to convert function return value to a Python type! The signature was (self) -> proto2::Map<std::__u::basic_string<char, std::__u::char_traits<char>, std::__u::allocator<char>>, tensorflow::ProfileOptions_AdvancedConfigValue> ``` PiperOrigin-RevId: 844865140
1 parent 74174af commit a8983f8

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/profiler_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import jax._src.test_util as jtu
3333

3434
from jax._src import profiler
35+
from jax._src.lib import ifrt_version
3536
from jax import jit
3637

3738

@@ -508,5 +509,20 @@ def on_profile():
508509
unittest.mock.ANY,
509510
)
510511

512+
def test_advanced_configuration_getter(self):
513+
if ifrt_version < 41:
514+
self.skipTest("advanced_configuration getter is newly added")
515+
516+
options = jax.profiler.ProfileOptions()
517+
advanced_config = {
518+
"tpu_trace_mode": "TRACE_COMPUTE",
519+
"tpu_num_sparse_cores_to_trace": 1,
520+
"enableFwThrottleEvent": True,
521+
}
522+
options.advanced_configuration = advanced_config
523+
returned_config = options.advanced_configuration
524+
self.assertDictEqual(returned_config, advanced_config)
525+
526+
511527
if __name__ == "__main__":
512528
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)