Skip to content

Commit 3075641

Browse files
[XLA] Enhance set_xla_metadata to further support tagging gradient operations.
PiperOrigin-RevId: 831469459
1 parent e4a25e4 commit 3075641

File tree

2 files changed

+41
-13
lines changed

2 files changed

+41
-13
lines changed

jax/_src/xla_metadata.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,6 @@ def _target_op_to_attach_metadata(value_mlir: ir.Value) -> ir.Operation | None:
131131
op = value_mlir.owner
132132
if op is None or isinstance(op, ir.Block):
133133
return None
134-
# TODO(nbasile): Add logic for handling multiply-by-constant-1.0 ops, which
135-
# are often added by jax gradients.
136-
# [Couple this change with tagging gradient ops.]
137134
return op
138135

139136

tests/xla_metadata_test.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,17 @@ class XlaMetadataTest(jtu.JaxTestCase):
3333

3434
def _assert_metadata_appears_once_per_op(
3535
self,
36-
stable_hlo_text: str,
36+
hlo_text: str,
3737
expected_tagged_ops: list[str],
3838
metadata: dict[str, str],
3939
):
40-
attribute_strings = [f'{k} = "{v}"' for k, v in metadata.items()]
40+
attribute_strings = [f'{k}="{v}"' for k, v in metadata.items()]
4141
op_with_metadata_count = {op: 0 for op in expected_tagged_ops}
4242

43-
for line in stable_hlo_text.splitlines():
43+
for line in hlo_text.splitlines():
4444
for op in expected_tagged_ops:
45-
if (op in line and all(attr in line for attr in attribute_strings)
46-
and "mhlo.frontend_attributes = " in line):
45+
if (str(op + "(") in line and all(attr in line for attr in attribute_strings)
46+
and "frontend_attributes=" in line):
4747
op_with_metadata_count[op] += 1
4848

4949
for op in op_with_metadata_count:
@@ -53,7 +53,7 @@ def _assert_metadata_appears_once_per_op(
5353
f"Expected op '{op}' to have the metadata exactly once,"
5454
f" but found it {op_with_metadata_count[op]} times\n"
5555
f"Metadata: {metadata}\n"
56-
f"StableHLO Graph:\n\n{stable_hlo_text}",
56+
f"HLO Graph:\n\n{hlo_text}",
5757
)
5858

5959
def test_f_jitted(self):
@@ -384,7 +384,38 @@ def wrapped_fn(x):
384384
return set_xla_metadata(fn(x), **metadata)
385385

386386
x_scalar = jnp.array(0.7)
387-
text = jax.jit(wrapped_fn).lower(x_scalar).as_text()
387+
text = jax.jit(wrapped_fn).lower(x_scalar).as_text("hlo")
388+
self._assert_metadata_appears_once_per_op(
389+
text, [expected_tagged_op], metadata)
390+
391+
@parameterized.parameters(
392+
("x*x", lambda x: x * x, "add"),
393+
# TODO(b/459818130): Re-enable once stablehlo changes (cl/797055546) are on HEAD.
394+
# ("sin(x)", jnp.sin, "cosine"),
395+
("tanh(x)", jnp.tanh, "add"),
396+
("1/x", lambda x: 1 / x, "negate"),
397+
("sinc(x)", jnp.sinc, "call"),
398+
)
399+
def test_value_grad_tagging(self, name, fn, expected_tagged_op):
400+
metadata = {"test_value_grad_tagging": name}
401+
402+
@jax.custom_vjp
403+
def wrapped_fn(x):
404+
return fn(x)
405+
406+
def fwd(*args):
407+
primal_out, vjp_fn = jax.vjp(fn, *args)
408+
return primal_out, vjp_fn
409+
410+
def bwd(vjp_fn, cts_in):
411+
cts_out = vjp_fn(cts_in)
412+
cts_out = set_xla_metadata(cts_out, **metadata)
413+
return cts_out
414+
415+
wrapped_fn.defvjp(fwd, bwd)
416+
417+
x_scalar = jnp.array(0.7)
418+
text = jax.jit(jax.grad(wrapped_fn)).lower(x_scalar).as_text("hlo")
388419
self._assert_metadata_appears_once_per_op(
389420
text, [expected_tagged_op], metadata)
390421

@@ -404,7 +435,7 @@ def vmapped_fn(x_item, y_item, z_item):
404435
z_batch = rng.random((batch_size, num_rows, num_rows)).astype(np.float32)
405436
inputs = (x_batch, y_batch, z_batch)
406437

407-
text = jax.jit(vmapped_fn).lower(*inputs).as_text()
438+
text = jax.jit(vmapped_fn).lower(*inputs).as_text("hlo")
408439
self._assert_metadata_appears_once_per_op(
409440
text, ["add", "subtract"], metadata)
410441

@@ -419,7 +450,7 @@ def test_sharding_support_value_tagging(self):
419450
def wrapped_fn(x):
420451
return set_xla_metadata(x * 2.0, **metadata)
421452

422-
text = jax.jit(wrapped_fn).lower(arr).as_text()
453+
text = jax.jit(wrapped_fn).lower(arr).as_text("hlo")
423454
self._assert_metadata_appears_once_per_op(text, ["multiply"], metadata)
424455

425456
def test_scan_support_value_tagging(self):
@@ -435,7 +466,7 @@ def scan_fn(init_carry, inputs_arr):
435466
return jax.lax.scan(scan_body_val_with_metadata, init_carry, inputs_arr)
436467

437468
inputs = (jnp.array(0.0), jnp.arange(1, 4, dtype=jnp.float32))
438-
text = jax.jit(scan_fn).lower(*inputs).as_text()
469+
text = jax.jit(scan_fn).lower(*inputs).as_text("hlo")
439470
self._assert_metadata_appears_once_per_op(text, ["multiply"], metadata)
440471

441472

0 commit comments

Comments
 (0)