@@ -3101,7 +3101,7 @@ def cond(pred):
31013101 def test_grad_of_bool_vjp3 (self ):
31023102 def cond (pred ):
31033103 return lax .cond (pred , lambda _ : 1. , lambda _ : 2. , 1. )
3104- value , f_vjp = api .vjp3 (cond , True )
3104+ value , f_vjp = api .vjp (cond , True )
31053105 grd , = f_vjp (1. )
31063106 self .assertEqual (value , 1. )
31073107 self .assertEqual (grd , np .zeros (shape = (), dtype = float0 ))
@@ -3213,6 +3213,12 @@ def f():
32133213 self .assertNotRegex (str (j_module ),
32143214 f"stablehlo.constant dense.*tensor<{ const_size } x" )
32153215
3216+ def test_basic_vjp3 (self ):
3217+ f = jax .jit (lambda x : jnp .sin (jnp .sin (x )))
3218+ _ , f_vjp = jax .vjp (f , 1. )
3219+ g , = f_vjp (1.0 )
3220+ self .assertAllClose (g , jnp .cos (jnp .sin (1. )) * jnp .cos (1. ), check_dtypes = False )
3221+
32163222 def test_constants_not_in_lowering_scan (self ):
32173223 if not config .use_simplified_jaxpr_constants .value :
32183224 self .skipTest ("Works only with simplified Jaxpr consts" )
@@ -6737,11 +6743,7 @@ def f(x):
67376743 return lax .cond (x .sum () > 0 , f , lambda x : x , x )
67386744
67396745 _ , f_vjp = api .vjp (f , jnp .ones ((5 , 5 )))
6740- if config .vjp3 .value :
6741- jaxpr_text = str (f_vjp .jaxpr )
6742- else :
6743- jaxpr_text = str (f_vjp .jaxpr )
6744-
6746+ jaxpr_text = str (f_vjp .jaxpr )
67456747 self .assertEqual (jaxpr_text .count (' sin ' ), 2 )
67466748 self .assertEqual (jaxpr_text .count (' cos ' ), 3 )
67476749 # Five calls to dot_general in the backward pass because we have two for
@@ -7806,7 +7808,7 @@ def test_basic_unused(self):
78067808 def test_basic_unused_vjp3 (self ):
78077809 f = jnp .sin
78087810 primals = 3. ,
7809- y , f_vjp = api .vjp3 (f , * primals )
7811+ y , f_vjp = api .vjp (f , * primals )
78107812 x_ct , = f_vjp (1. )
78117813 self .assertAllClose (y , jnp .sin (3. ))
78127814 self .assertAllClose (x_ct , jnp .cos (3. ))
@@ -7821,7 +7823,7 @@ def test_basic_opaque(self):
78217823 def test_basic_opaque_vjp3 (self ):
78227824 f = jnp .sin
78237825 primals = 3. ,
7824- _ , f_vjp = api .vjp3 (f , * primals )
7826+ _ , f_vjp = api .vjp (f , * primals )
78257827 assert f_vjp .opaque_residuals # can detect if opaque res are used
78267828
78277829 def test_basic_pytree_error (self ):
@@ -7841,7 +7843,7 @@ def f(x):
78417843 # def f(x):
78427844 # return [x['hi'] * x['bye']]
78437845
7844- # y, f_vjp = api.vjp3 (f, {'hi': 2., 'bye': 3.})
7846+ # y, f_vjp = api.vjp (f, {'hi': 2., 'bye': 3.})
78457847 # arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.})
78467848 # self.assertAllClose(y, [6.])
78477849 # self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.})
@@ -7890,7 +7892,7 @@ def f2(x, w):
78907892
78917893 x = jnp .ones ((3 , 4 ))
78927894 w = jnp .ones ((4 , 4 ))
7893- y , f2_vjp = api .vjp3 (f2 , x , w )
7895+ y , f2_vjp = api .vjp (f2 , x , w )
78947896 f2_vjp .args_res [1 ] = None
78957897 y_grad = jnp .ones_like (y )
78967898 f2_vjp .args_res [1 ] = w
@@ -7964,7 +7966,7 @@ def foo(x):
79647966
79657967 def test_grad_traceback (self ):
79667968 # TODO(dougalm): improve this
7967- expected_depth = 12
7969+ expected_depth = 11
79687970 init_depth = self .cur_depth ()
79697971
79707972 def foo (x ):
@@ -7987,7 +7989,7 @@ def foo(x):
79877989 def test_custom_vjp_traceback (self ):
79887990 # TODO(dougalm): improve this
79897991 expected_depth_f = 10
7990- expected_depth_f_fwd = 20
7992+ expected_depth_f_fwd = 19
79917993 expected_depth_f_rev = 12
79927994 init_depth = self .cur_depth ()
79937995 @jax .custom_vjp
0 commit comments