diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index 050ec7f28f3a..77a560ebdfe3 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -1472,26 +1472,40 @@ def scaled_dot_general(
   )
   return out
 
+@custom_derivatives.custom_jvp
+@api.jit
+def prelu(x: ArrayLike, a: ArrayLike) -> Array:
+  r"""Parametric ReLU (PReLU) activation function.
+
+  Computes the element-wise function:
+
+  .. math::
+    \mathrm{prelu}(x) = \begin{cases}
+      x, & x \ge 0 \\
+      a x, & x < 0
+    \end{cases}
+
+  Like :func:`relu`, but the slope ``a`` for negative inputs is a parameter
+  (typically learned by backpropagation), which lets the network retain some
+  negative information and avoid dead neurons.
+
+  Args:
+    x: input array.
+    a: slope for negative inputs; a scalar or an array broadcastable to ``x``.
+
+  Returns:
+    An array of the same shape as ``x`` with PReLU applied element-wise.
+
+  Examples:
+    >>> x = jnp.array([0.1, 0.2, 3.0, -2.0])
+    >>> jax.nn.prelu(x, 0.1)
+    Array([ 0.1,  0.2,  3. , -0.2], dtype=float32)
+  """
+  x = jnp.asarray(x)
+  a = jnp.asarray(a)
+  # The max/min formulation avoids boolean masking and vectorizes cleanly.
+  return jnp.maximum(x, 0) + a * jnp.minimum(x, 0)
+
+prelu.defjvps(
+    # d/dx: 1 where x >= 0, a where x < 0.
+    lambda g, ans, x, a: jnp.where(x >= 0, g, a * g),
+    # d/da: x where x < 0, 0 elsewhere; jnp.where broadcasts a's tangent to x's shape.
+    lambda g, ans, x, a: jnp.where(x < 0, x * g, 0.0),
+)
+
-@custom_derivatives.custom_jvp
-@api.jit
-def log1mexp(x: ArrayLike) -> Array:
-  r"""Numerically stable calculation of :math:`\log(1 - \exp(-x))`.
-  This function is undefined for :math:`x < 0`.
-  Based on `TensorFlow's implementation `_.
-  References:
-    .. [1] Martin Mächler. `Accurately Computing log(1 − exp(−|a|)) Assessed by the Rmpfr package.
-    `_.
-  """
-  x = numpy_util.ensure_arraylike("log1mexp", x)
-  c = jnp.log(2.0)
-  return jnp.where(
-      x < c,
-      jnp.log(-jnp.expm1(-x)),
-      jnp.log1p(-jnp.exp(-x)),
-  )
-log1mexp.defjvps(lambda g, ans, x: g / jnp.expm1(x))
diff --git a/tests/nn_test.py b/tests/nn_test.py
index 7fb2fce402fb..3486ca07b0cf 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -493,7 +493,27 @@ def testRelu6Grad(self):
     check_grads(nn.relu6, (-1.,), order=3, rtol=rtol)
     self.assertAllClose(jax.grad(nn.relu6)(0.), 0., check_dtypes=False)
     self.assertAllClose(jax.grad(nn.relu6)(6.), 0., check_dtypes=False)
-
+
+  def testPreluValue(self):
+    x = jnp.array([[1.0, -1.0, 2.0, -2.0]], dtype=jnp.float32)
+    a = jnp.array([0.1], dtype=jnp.float32)
+    expected = jnp.where(x >= 0, x, a * x)
+
+    y = nn.prelu(x, a)
+    self.assertAllClose(y, expected, rtol=1e-6, atol=1e-6)
+
+    # A Python scalar slope must give the same result as a 1-element array.
+    y_scalar = nn.prelu(x, 0.1)
+    self.assertAllClose(y_scalar, expected, rtol=1e-6, atol=1e-6)
+
+  def testPreluGrad(self):
+    # Avoid x == 0: the numerical gradient check is not valid at the kink.
+    x = jnp.array([[1.0, -1.0, 0.5, -2.0]], dtype=jnp.float32)
+    a = jnp.array([0.1], dtype=jnp.float32)
+
+    def loss(x, a):
+      return jnp.sum(nn.prelu(x, a))
+
+    check_grads(loss, (x, a), order=1, modes=["fwd", "rev"])
+
   def testSoftplusValue(self):
     val = nn.softplus(89.)
     self.assertAllClose(val, 89., check_dtypes=False)
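
Usage sketch (not part of the patch): a minimal check of how the new prelu is
expected to behave once the diff above is applied, assuming prelu is also
re-exported from jax/nn/__init__.py so it is reachable as jax.nn.prelu (that
export is not shown in this diff). The values mirror the docstring example;
in practice the slope `a` would be a trained parameter.

    import jax
    import jax.numpy as jnp

    x = jnp.array([0.1, 0.2, 3.0, -2.0])
    a = 0.1
    print(jax.nn.prelu(x, a))   # [ 0.1  0.2  3.  -0.2]
    # Gradient w.r.t. x follows the custom JVP: 1 for x >= 0, a for x < 0.
    print(jax.grad(lambda x: jax.nn.prelu(x, a).sum())(x))   # [1.  1.  1.  0.1]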