
Commit 404d644

yashk2810 authored and Google-ML-Automation committed

Remove config.vjp3 and the vjp3 API, since they are now replaced by jax.vjp.

PiperOrigin-RevId: 843971635

1 parent 8ce3512 · commit 404d644
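Migration note: callers of the removed experimental entry point only need a rename, since jax.vjp now provides the behavior that vjp3 (and the jax_vjp3 flag) used to gate. A minimal sketch of the before/after; the function and values here are illustrative, not taken from this change:

    import jax
    import jax.numpy as jnp

    f = lambda x: jnp.sin(jnp.sin(x))

    # Before this commit (removed): y, f_vjp = api.vjp3(f, 1.0), with api = jax._src.api
    # After: the same call goes through the public API.
    y, f_vjp = jax.vjp(f, 1.0)
    x_ct, = f_vjp(1.0)  # apply the pullback to an output cotangent
    # x_ct == jnp.cos(jnp.sin(1.0)) * jnp.cos(1.0)

The updated tests below follow exactly this pattern.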

File tree: 9 files changed (+29 additions, −96 deletions)

jax/_src/api.py

Lines changed: 1 addition & 31 deletions
@@ -2212,31 +2212,6 @@ def vjp(
       fun, debug_info=debug_info("vjp", fun, primals, {}))
   return _vjp(wrapped_fun, *primals, has_aux=has_aux)
 
-def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
-  """Variant of vjp() that takes an lu.WrappedFun."""
-  if config.vjp3.value:
-    return _vjp3(fun, *primals, has_aux=has_aux)
-  primals_flat, in_tree = tree_flatten(primals)
-  primals_flat = [canonicalize_value(v) if not isinstance(v, core.Tracer) else v
-                  for v in primals_flat]
-  for arg in primals_flat: dispatch.check_arg(arg)
-  if not has_aux:
-    flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-    out_primals, vjp = ad.vjp(flat_fun, primals_flat)
-    out_tree = out_tree()
-  else:
-    flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
-    out_primals, vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
-    out_tree, aux_tree = out_aux_trees()
-  out_primal_avals = map(shaped_abstractify, out_primals)
-  out_primal_py = tree_unflatten(out_tree, out_primals)
-  vjp_py = Partial(partial(_vjp_pullback_wrapper, fun.__name__,
-                           out_primal_avals, (out_tree, in_tree)), vjp)
-  if not has_aux:
-    return out_primal_py, vjp_py
-  else:
-    return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux)
-
 @partial(api_boundary, repro_api_name="jax.experimental.saved_input_vjp")
 def saved_input_vjp(f: Callable, which: Sequence[bool], *primals,
                     allow_unused: bool = True, allow_opaque: bool = True):

@@ -2332,12 +2307,7 @@ class RSpec:
 si_vjp = saved_input_vjp
 
 
-def vjp3(f, *primals, has_aux=False):
-  dbg = debug_info("vjp", f, primals, {})
-  fun = lu.wrap_init(f, debug_info=dbg)
-  return _vjp3(fun, *primals, has_aux=has_aux)
-
-def _vjp3(fun, *primals, has_aux=False):
+def _vjp(fun, *primals, has_aux=False):
   canon = lambda x: x if isinstance(x, core.Tracer) else canonicalize_value(x)
   primals = tree_map(canon, primals)
   primals_flat, in_tree = tree_flatten(primals)
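As the merged _vjp(fun, *primals, has_aux=False) signature above shows, the auxiliary-output path is preserved. A small sketch of the corresponding public behavior of jax.vjp with has_aux=True; the function and values are illustrative:

    import jax
    import jax.numpy as jnp

    def f(x):
      return jnp.sin(x), {"residual": 2.0 * x}  # (primal output, aux pytree)

    y, f_vjp, aux = jax.vjp(f, 1.0, has_aux=True)  # aux is passed through unchanged
    x_ct, = f_vjp(1.0)  # cotangent for the primal output only
    # y == jnp.sin(1.0), x_ct == jnp.cos(1.0), aux == {"residual": 2.0}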

jax/_src/config.py

Lines changed: 0 additions & 6 deletions
@@ -1833,12 +1833,6 @@ def _validate_default_device(val):
     help='Enable error checks for mutable arrays that rule out aliasing.',
     include_in_trace_context=True)
 
-vjp3 = bool_state(
-    name='jax_vjp3',
-    default=True,
-    upgrade=True,
-    help='Use new backward-pass code in jax.vjp')
-
 refs_to_pins = bool_state(
     name='jax_refs_to_pins',
     default=False,

jax/_src/interpreters/ad.py

Lines changed: 1 addition & 18 deletions
@@ -26,7 +26,7 @@
 from jax._src import linear_util as lu
 from jax._src.interpreters import partial_eval as pe
 from jax._src.tree_util import (tree_flatten, tree_unflatten,
-                                register_pytree_node, Partial, PyTreeDef)
+                                register_pytree_node, PyTreeDef)
 from jax._src import mesh as mesh_lib
 from jax._src import core
 from jax._src import source_info_util

@@ -302,23 +302,6 @@ def linearize(traceable: lu.WrappedFun, *primals, **kwargs):
   else:
     return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
 
-def vjp(traceable: lu.WrappedFun, primals, has_aux=False):
-  if not has_aux:
-    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
-  else:
-    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
-
-  def unbound_vjp(pvals, jaxpr, consts, *cts):
-    cts = tuple(ct for ct, pval in zip(cts, pvals) if not pval.is_known())
-    dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
-    arg_cts = backward_pass(jaxpr, True, consts, dummy_args, cts)
-    return map(instantiate_zeros, arg_cts)
-
-  vjp_ = Partial(partial(unbound_vjp, pvals, jaxpr), consts)
-  if not has_aux:
-    return out_primals, vjp_
-  else:
-    return out_primals, vjp_, aux
 
 # NOTE: The FIXMEs below are caused by primal/tangent mixups (type
 # errors if you will)
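The deleted ad.vjp helper built a pullback by linearizing and then running backward_pass over the resulting jaxpr. That linearize-then-transpose recipe remains expressible through public APIs; this is a rough sketch of the equivalence, not a drop-in replacement for the internal helper:

    import jax
    import jax.numpy as jnp

    def f(x):
      return jnp.sin(jnp.sin(x))

    x = 1.0
    y, f_jvp = jax.linearize(f, x)          # primal output plus the linearized tangent map
    f_vjp = jax.linear_transpose(f_jvp, x)  # transpose the linear map into a pullback
    x_ct, = f_vjp(1.0)
    # Matches jax.vjp(f, x)[1](1.0)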

jax/interpreters/ad.py

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@
   primitive_jvps as primitive_jvps,
   primitive_transposes as primitive_transposes,
   reducing_transposes as reducing_transposes,
-  vjp as vjp,
   zeros_like_aval as zeros_like_aval,
 )

tests/api_test.py

Lines changed: 14 additions & 12 deletions
@@ -3101,7 +3101,7 @@ def cond(pred):
   def test_grad_of_bool_vjp3(self):
     def cond(pred):
       return lax.cond(pred, lambda _: 1., lambda _: 2., 1.)
-    value, f_vjp = api.vjp3(cond, True)
+    value, f_vjp = api.vjp(cond, True)
     grd, = f_vjp(1.)
     self.assertEqual(value, 1.)
     self.assertEqual(grd, np.zeros(shape=(), dtype=float0))

@@ -3213,6 +3213,12 @@ def f():
     self.assertNotRegex(str(j_module),
                         f"stablehlo.constant dense.*tensor<{const_size}x")
 
+  def test_basic_vjp3(self):
+    f = jax.jit(lambda x: jnp.sin(jnp.sin(x)))
+    _, f_vjp = jax.vjp(f, 1.)
+    g, = f_vjp(1.0)
+    self.assertAllClose(g, jnp.cos(jnp.sin(1.)) * jnp.cos(1.), check_dtypes=False)
+
   def test_constants_not_in_lowering_scan(self):
     if not config.use_simplified_jaxpr_constants.value:
       self.skipTest("Works only with simplified Jaxpr consts")

@@ -6737,11 +6743,7 @@ def f(x):
       return lax.cond(x.sum() > 0, f, lambda x: x, x)
 
     _, f_vjp = api.vjp(f, jnp.ones((5, 5)))
-    if config.vjp3.value:
-      jaxpr_text = str(f_vjp.jaxpr)
-    else:
-      jaxpr_text = str(f_vjp.jaxpr)
-
+    jaxpr_text = str(f_vjp.jaxpr)
     self.assertEqual(jaxpr_text.count(' sin '), 2)
     self.assertEqual(jaxpr_text.count(' cos '), 3)
     # Five calls to dot_general in the backward pass because we have two for

@@ -7806,7 +7808,7 @@ def test_basic_unused(self):
   def test_basic_unused_vjp3(self):
     f = jnp.sin
     primals = 3.,
-    y, f_vjp = api.vjp3(f, *primals)
+    y, f_vjp = api.vjp(f, *primals)
     x_ct, = f_vjp(1.)
     self.assertAllClose(y, jnp.sin(3.))
     self.assertAllClose(x_ct, jnp.cos(3.))

@@ -7821,7 +7823,7 @@ def test_basic_opaque(self):
   def test_basic_opaque_vjp3(self):
     f = jnp.sin
     primals = 3.,
-    _, f_vjp = api.vjp3(f, *primals)
+    _, f_vjp = api.vjp(f, *primals)
     assert f_vjp.opaque_residuals  # can detect if opaque res are used
 
   def test_basic_pytree_error(self):

@@ -7841,7 +7843,7 @@ def f(x):
     # def f(x):
     #   return [x['hi'] * x['bye']]
 
-    # y, f_vjp = api.vjp3(f, {'hi': 2., 'bye': 3.})
+    # y, f_vjp = api.vjp(f, {'hi': 2., 'bye': 3.})
     # arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.})
     # self.assertAllClose(y, [6.])
     # self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.})

@@ -7890,7 +7892,7 @@ def f2(x, w):
 
     x = jnp.ones((3, 4))
     w = jnp.ones((4, 4))
-    y, f2_vjp = api.vjp3(f2, x, w)
+    y, f2_vjp = api.vjp(f2, x, w)
     f2_vjp.args_res[1] = None
     y_grad = jnp.ones_like(y)
     f2_vjp.args_res[1] = w

@@ -7964,7 +7966,7 @@ def foo(x):
 
   def test_grad_traceback(self):
     # TODO(dougalm): improve this
-    expected_depth = 12
+    expected_depth = 11
     init_depth = self.cur_depth()
 
     def foo(x):

@@ -7987,7 +7989,7 @@ def foo(x):
   def test_custom_vjp_traceback(self):
     # TODO(dougalm): improve this
     expected_depth_f = 10
-    expected_depth_f_fwd = 20
+    expected_depth_f_fwd = 19
     expected_depth_f_rev = 12
     init_depth = self.cur_depth()
     @jax.custom_vjp

tests/custom_api_test.py

Lines changed: 1 addition & 1 deletion
@@ -1472,7 +1472,7 @@ def sin_jvp(primals, tangents):
       (x, y), (x_dot, y_dot) = primals, tangents
       del y_dot  # ignore lol
       return div(x, y), div(x_dot, y)
-    _, f_vjp = api.vjp3(lambda x: div(x, 2.), 1.)
+    _, f_vjp = api.vjp(lambda x: div(x, 2.), 1.)
     ans, = f_vjp(1.)
     self.assertAllClose(ans, 1./2, check_dtypes=False)

tests/lax_control_flow_test.py

Lines changed: 4 additions & 11 deletions
@@ -41,7 +41,6 @@
 import jax.numpy as jnp  # scan tests use numpy
 import jax.scipy as jsp
 from jax._src import dispatch
-from jax._src.api import vjp3
 from jax._src.lax import control_flow as lax_control_flow
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir

@@ -2777,10 +2776,7 @@ def cumprod(x):
     # ==> Yes, we don't want to change autodiff const behavior. We must make
     # these tessts pass under use_simplified_jaxpr_constants.
     if not config.use_simplified_jaxpr_constants.value:
-      if config.vjp3.value:
-        ext_res, = vjp_fun.args_res
-      else:
-        *_, ext_res = vjp_fun.args[0].args[0]
+      ext_res, = vjp_fun.args_res
       self.assertIs(ext_res, x)
 
     if remat is not None:

@@ -2790,10 +2786,7 @@ def cumprod(x):
     x = rng.randn(32, 2, 32).astype('float32')  # numpy.ndarray, not Array
     _, vjp_fun = jax.vjp(cumprod, x)
     if not config.use_simplified_jaxpr_constants.value:
-      if config.vjp3.value:
-        ext_res, *_ = vjp_fun.opaque_residuals
-      else:
-        *_, ext_res = vjp_fun.args[0].args[0]
+      ext_res, *_ = vjp_fun.opaque_residuals
       self.assertIsInstance(ext_res, jax.Array)
 
   def test_scan_vmap_collectives(self):

@@ -3498,14 +3491,14 @@ def test_cond_basic_vjp3(self):
     def f(x):
       return jax.lax.cond(True, jnp.sin, lambda x: x, x)
 
-    _, f_vjp = vjp3(f, 1.)
+    _, f_vjp = jax.vjp(f, 1.)
     g, = f_vjp(1.0)
     self.assertAllClose(g, jnp.cos(1.), check_dtypes=False)
 
     def h(x):
       return jax.lax.cond(True, jnp.sin, lambda x: 1., x)
 
-    _, h_vjp = vjp3(h, 1.)
+    _, h_vjp = jax.vjp(h, 1.)
     g, = h_vjp(1.0)
     self.assertAllClose(g, jnp.cos(1.), check_dtypes=False)

tests/mutable_array_test.py

Lines changed: 8 additions & 9 deletions
@@ -25,7 +25,6 @@
 from jax._src import core
 from jax._src import config
 from jax._src import test_util as jtu
-from jax._src.api import vjp3
 from jax._src.util import safe_map, safe_zip
 from jax._src.interpreters import mlir
 from jax.sharding import NamedSharding, PartitionSpec as P, AxisType

@@ -524,8 +523,8 @@ def stash_grads_bwd(grads_ref, g):
 
     grads_ref = core.new_ref(jnp.float32(0.))
     x = jnp.float32(1.)
-    _, f_vjp, *maybe_aux = vjp3(lambda x: primal(grads_ref, x), x,
-                                has_aux=has_aux)
+    _, f_vjp, *maybe_aux = jax.vjp(
+        lambda x: primal(grads_ref, x), x, has_aux=has_aux)
     _ = f_vjp(jnp.float32(1.))
     self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False)
     if has_aux:

@@ -553,15 +552,15 @@ def stash_grads_bwd(stash_ref, g):
     stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd)
 
     stash_ref = core.new_ref(jnp.float32(0.))
-    _, f_vjp = vjp3(lambda x: primal(stash_ref, x), jnp.float32(1.))
+    _, f_vjp = jax.vjp(lambda x: primal(stash_ref, x), jnp.float32(1.))
     grads_val, = f_vjp(jnp.float32(1.))
     self.assertAllClose(stash_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False)
     self.assertAllClose(grads_val, jnp.cos(jnp.sin(1.)) * jnp.cos(1.),
                         check_dtypes=False)
 
     stash_ref = core.new_ref(jnp.float32(0.))
     grads_ref = core.new_ref(jnp.float32(0.))
-    _, f_vjp = vjp3(lambda x: primal(stash_ref, x), jnp.float32(1.))
+    _, f_vjp = jax.vjp(lambda x: primal(stash_ref, x), jnp.float32(1.))
     _ = f_vjp.with_refs(grads_ref)(jnp.float32(1.))
     self.assertAllClose(stash_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False)
     self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)) * jnp.cos(1.),

@@ -849,7 +848,7 @@ def process_batch(Ws, xs_batch):
     grad_acc = jax.new_ref(jnp.zeros_like(Ws))  # CHANGED
 
     def process_mubatch(_, xs):
-      loss, f_vjp = vjp3(lambda Ws: mubatch_loss(Ws, xs), Ws)  # CHANGED
+      loss, f_vjp = jax.vjp(lambda Ws: mubatch_loss(Ws, xs), Ws)  # CHANGED
       f_vjp.with_refs(grad_acc)(jnp.ones_like(loss))  # CHANGED
       return (), loss
 

@@ -924,7 +923,7 @@ def f_bwd(_, g):
     self.assertAllClose(y, 3.14, check_dtypes=False)
 
     # this exercises the fallback path, not a fancy transpose
-    _, f_vjp = vjp3(lambda x: f(jax.new_ref(x)), 3.14)
+    _, f_vjp = jax.vjp(lambda x: f(jax.new_ref(x)), 3.14)
     g, = f_vjp(1.)
     self.assertAllClose(g, 1., check_dtypes=False)
 

@@ -969,7 +968,7 @@ def body(_, xy):
       return z.sum()
 
     grad_accum = jax.new_ref(jnp.zeros(5))
-    _, f_vjp = vjp3(f, jnp.ones(5))
+    _, f_vjp = jax.vjp(f, jnp.ones(5))
     _, = f_vjp.with_refs(grad_accum)(1.)
     self.assertAllClose(grad_accum[...], jnp.arange(5.))
 

@@ -978,7 +977,7 @@ def test_vmap_with_vjp3(self):
     def grad_via_ref(f):
      def wrapper(*args):
        grad_accum = jax.tree.map(lambda x: jax.new_ref(jnp.zeros_like(x)), args)
-       out, f_vjp = vjp3(f, *args)
+       out, f_vjp = jax.vjp(f, *args)
        f_vjp.with_refs(*grad_accum)(jnp.ones_like(out))
        return jax.tree.map(lambda x: jax.freeze(x), grad_accum)
      return wrapper

tests/pjit_test.py

Lines changed: 0 additions & 7 deletions
@@ -38,7 +38,6 @@
 from jax._src import dtypes
 from jax import stages
 from jax import lax
-from jax._src.api import vjp3
 from jax._src.lax import lax as lax_internal
 from jax.lax import with_sharding_constraint
 from jax._src import prng

@@ -1309,12 +1308,6 @@ def test_device_put_copy_donate(self):
     self.assertNotDeleted(z)
     self.assertArraysEqual(a, x * 2)
 
-  def test_basic_vjp3(self):
-    f = jax.jit(lambda x: jnp.sin(jnp.sin(x)))
-    _, f_vjp = vjp3(f, 1.)
-    g, = f_vjp(1.0)
-    self.assertAllClose(g, jnp.cos(jnp.sin(1.)) * jnp.cos(1.), check_dtypes=False)
-
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class AutoShardingPjitTest(jtu.JaxTestCase):
