48 changes: 31 additions & 17 deletions jax/_src/nn/functions.py
@@ -1472,26 +1472,40 @@ def scaled_dot_general(
  )

  return out


@custom_derivatives.custom_jvp
@api.jit
def log1mexp(x: ArrayLike) -> Array:
  r"""Numerically stable calculation of :math:`\log(1 - \exp(-x))`.

  This function is undefined for :math:`x < 0`.

  Based on `TensorFlow's implementation
  <https://www.tensorflow.org/probability/api_docs/python/tfp/math/log1mexp>`_.

  References:
    .. [1] Martin Mächler. `Accurately Computing log(1 − exp(−|a|)) Assessed by
       the Rmpfr package.
       <https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf>`_.
  """
Collaborator:
We should add an example or two here.

  x = numpy_util.ensure_arraylike("log1mexp", x)
  c = jnp.log(2.0)
  return jnp.where(
      x < c,
      jnp.log(-jnp.expm1(-x)),
      jnp.log1p(-jnp.exp(-x)),
  )

log1mexp.defjvps(lambda g, ans, x: g / jnp.expm1(x))
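
Following the reviewer's note about adding an example or two, here is a sketch of what a docstring example could demonstrate. It assumes the function ends up exposed as jax.nn.log1mexp and uses JAX's default float32 precision; it is illustrative, not part of the diff.

import jax.numpy as jnp
from jax import nn

# For moderate x the stable and naive formulas agree.
x = 2.0
assert jnp.allclose(nn.log1mexp(x), jnp.log(1 - jnp.exp(-x)))

# For tiny x the naive formula collapses: exp(-x) rounds to 1, so
# log(1 - exp(-x)) becomes log(0) = -inf, while log1mexp stays finite
# by evaluating log(-expm1(-x)) instead.
tiny = 1e-20
print(nn.log1mexp(tiny))            # finite, roughly log(1e-20)
print(jnp.log(1 - jnp.exp(-tiny)))  # -inf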

@custom_jvp
@jax.jit
def prelu(x: ArrayLike, a: ArrayLike) -> Array:
  """Applies the PReLU (Parametric ReLU) activation function element-wise.

  Similar to ReLU, but includes a slope parameter ``a`` that controls the
  slope for negative inputs, allowing the network to adaptively retain some
  negative information and prevent dead neurons.

  If ``x >= 0`` this returns ``x``; otherwise it returns ``a * x``.

  Args:
    x: Input array.
    a: Slope parameter; typically trained through backpropagation.

  Returns:
    Output array with the PReLU activation applied.

  Examples:
    x = jnp.array([0.1, 0.2, 3., -2.])
    a = 0.1  # placeholder scalar; in practice ``a`` is trained
    prelu(x, a)  # [ 0.1  0.2  3.  -0.2]
  """
  x = jnp.asarray(x)
Collaborator:
Use check_arraylike before passing these to asarray.

Author:
I had actually used it at the beginning of the function, in the parentheses.

Author (@MythicArrow, Jun 3, 2025):
def prelu(x: ArrayLike, a: ArrayLike) -> Array:
Here I had used it.
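
For context, the ArrayLike annotation in the signature is only a static type hint and performs no runtime check. A minimal sketch of the validation the reviewer is suggesting, using JAX's internal check_arraylike helper (calling convention assumed from its use elsewhere in JAX), might look like:

check_arraylike("prelu", x, a)        # raises a TypeError early for invalid inputs
x, a = jnp.asarray(x), jnp.asarray(a)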

  a = jnp.asarray(a)

  return jnp.maximum(x, 0) + a * jnp.minimum(x, 0)  # This approach is more idiomatic because it avoids boolean masking.


prelu.defjvps(
    lambda g, ans, x, a: jnp.where(x >= 0, g, a * g),  # tangent w.r.t. x: 1 where x >= 0, a where x < 0
    lambda g, ans, x, a: jnp.where(x < 0, x, 0) * g,   # tangent w.r.t. a: x where x < 0, 0 elsewhere
)
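
A brief usage sketch (illustrative only, not part of the diff): assuming this lands and is exported as jax.nn.prelu, a per-channel learned slope broadcasts against the trailing axis of the activations, and gradients flow to it like any other parameter. Names and shapes below are hypothetical.

import jax
import jax.numpy as jnp

# Hypothetical NHWC activations with one learned slope per channel.
key = jax.random.key(0)
activations = jax.random.normal(key, (8, 16, 16, 32))
a = jnp.full((32,), 0.25)                  # per-channel slopes, updated by the optimizer
out = jax.nn.prelu(activations, a)         # `a` broadcasts over the channel axis

# Gradient of a scalar loss with respect to the slopes.
grads = jax.grad(lambda a: jax.nn.prelu(activations, a).sum())(a)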


22 changes: 21 additions & 1 deletion tests/nn_test.py
@@ -493,7 +493,27 @@ def testRelu6Grad(self):
    check_grads(nn.relu6, (-1.,), order=3, rtol=rtol)
    self.assertAllClose(jax.grad(nn.relu6)(0.), 0., check_dtypes=False)
    self.assertAllClose(jax.grad(nn.relu6)(6.), 0., check_dtypes=False)

  def testPreluValue(self):
    x = jnp.array([[1.0, -1.0, 2.0, -2.0]], dtype=jnp.float32)
    a = jnp.array([0.1], dtype=jnp.float32)
    expected = jnp.where(x >= 0, x, a * x)

    y = nn.prelu(x, a)
    self.assertTrue(jnp.allclose(y, expected, rtol=1e-6, atol=1e-6))

    y_scalar = nn.prelu(x, 0.1)
    self.assertTrue(jnp.allclose(y_scalar, expected, rtol=1e-6, atol=1e-6))

  def testPreluGrad(self):
    x = jnp.array([[1.0, -1.0, 0.0, -2.0]], dtype=jnp.float32)
    a = jnp.array([0.1], dtype=jnp.float32)

    def prelu_sum(x, a):  # scalar loss so check_grads exercises both fwd and rev modes
      return jnp.sum(nn.prelu(x, a))

    check_grads(prelu_sum, (x, a), order=1, modes=["fwd", "rev"])


  def testSoftplusValue(self):
    val = nn.softplus(89.)
    self.assertAllClose(val, 89., check_dtypes=False)