@@ -33,17 +33,17 @@ class XlaMetadataTest(jtu.JaxTestCase):
3333
3434 def _assert_metadata_appears_once_per_op (
3535 self ,
36- stable_hlo_text : str ,
36+ hlo_text : str ,
3737 expected_tagged_ops : list [str ],
3838 metadata : dict [str , str ],
3939 ):
40- attribute_strings = [f'{ k } = "{ v } "' for k , v in metadata .items ()]
40+ attribute_strings = [f'{ k } = "{ v } "' for k , v in metadata .items ()]
4141 op_with_metadata_count = {op : 0 for op in expected_tagged_ops }
4242
43- for line in stable_hlo_text .splitlines ():
43+ for line in hlo_text .splitlines ():
4444 for op in expected_tagged_ops :
45- if (op in line and all (attr in line for attr in attribute_strings )
46- and "mhlo. frontend_attributes = " in line ):
45+ if (str ( op + "(" ) in line and all (attr in line for attr in attribute_strings )
46+ and "frontend_attributes= " in line ):
4747 op_with_metadata_count [op ] += 1
4848
4949 for op in op_with_metadata_count :
@@ -53,7 +53,7 @@ def _assert_metadata_appears_once_per_op(
5353 f"Expected op '{ op } ' to have the metadata exactly once,"
5454 f" but found it { op_with_metadata_count [op ]} times\n "
5555 f"Metadata: { metadata } \n "
56- f"StableHLO Graph:\n \n { stable_hlo_text } " ,
56+ f"HLO Graph:\n \n { hlo_text } " ,
5757 )
5858
5959 def test_f_jitted (self ):
@@ -384,7 +384,38 @@ def wrapped_fn(x):
384384 return set_xla_metadata (fn (x ), ** metadata )
385385
386386 x_scalar = jnp .array (0.7 )
387- text = jax .jit (wrapped_fn ).lower (x_scalar ).as_text ()
387+ text = jax .jit (wrapped_fn ).lower (x_scalar ).as_text ("hlo" )
388+ self ._assert_metadata_appears_once_per_op (
389+ text , [expected_tagged_op ], metadata )
390+
391+ @parameterized .parameters (
392+ ("x*x" , lambda x : x * x , "add" ),
393+ # TODO(b/459818130): Re-enable once stablehlo changes (cl/797055546) are on HEAD.
394+ # ("sin(x)", jnp.sin, "cosine"),
395+ ("tanh(x)" , jnp .tanh , "add" ),
396+ ("1/x" , lambda x : 1 / x , "negate" ),
397+ ("sinc(x)" , jnp .sinc , "call" ),
398+ )
399+ def test_value_grad_tagging (self , name , fn , expected_tagged_op ):
400+ metadata = {"test_value_grad_tagging" : name }
401+
402+ @jax .custom_vjp
403+ def wrapped_fn (x ):
404+ return fn (x )
405+
406+ def fwd (* args ):
407+ primal_out , vjp_fn = jax .vjp (fn , * args )
408+ return primal_out , vjp_fn
409+
410+ def bwd (vjp_fn , cts_in ):
411+ cts_out = vjp_fn (cts_in )
412+ cts_out = set_xla_metadata (cts_out , ** metadata )
413+ return cts_out
414+
415+ wrapped_fn .defvjp (fwd , bwd )
416+
417+ x_scalar = jnp .array (0.7 )
418+ text = jax .jit (jax .grad (wrapped_fn )).lower (x_scalar ).as_text ("hlo" )
388419 self ._assert_metadata_appears_once_per_op (
389420 text , [expected_tagged_op ], metadata )
390421
@@ -404,7 +435,7 @@ def vmapped_fn(x_item, y_item, z_item):
404435 z_batch = rng .random ((batch_size , num_rows , num_rows )).astype (np .float32 )
405436 inputs = (x_batch , y_batch , z_batch )
406437
407- text = jax .jit (vmapped_fn ).lower (* inputs ).as_text ()
438+ text = jax .jit (vmapped_fn ).lower (* inputs ).as_text ("hlo" )
408439 self ._assert_metadata_appears_once_per_op (
409440 text , ["add" , "subtract" ], metadata )
410441
@@ -419,7 +450,7 @@ def test_sharding_support_value_tagging(self):
419450 def wrapped_fn (x ):
420451 return set_xla_metadata (x * 2.0 , ** metadata )
421452
422- text = jax .jit (wrapped_fn ).lower (arr ).as_text ()
453+ text = jax .jit (wrapped_fn ).lower (arr ).as_text ("hlo" )
423454 self ._assert_metadata_appears_once_per_op (text , ["multiply" ], metadata )
424455
425456 def test_scan_support_value_tagging (self ):
@@ -435,7 +466,7 @@ def scan_fn(init_carry, inputs_arr):
435466 return jax .lax .scan (scan_body_val_with_metadata , init_carry , inputs_arr )
436467
437468 inputs = (jnp .array (0.0 ), jnp .arange (1 , 4 , dtype = jnp .float32 ))
438- text = jax .jit (scan_fn ).lower (* inputs ).as_text ()
469+ text = jax .jit (scan_fn ).lower (* inputs ).as_text ("hlo" )
439470 self ._assert_metadata_appears_once_per_op (text , ["multiply" ], metadata )
440471
441472
0 commit comments