Implementing PReLu #29147
@@ -1472,26 +1472,40 @@ def scaled_dot_general(

Unchanged context in the hunk (the end of `scaled_dot_general` and the existing `log1mexp`):

```python
    )
    return out


@custom_derivatives.custom_jvp
@api.jit
def log1mexp(x: ArrayLike) -> Array:
  r"""Numerically stable calculation of :math:`\log(1 - \exp(-x))`.

  This function is undefined for :math:`x < 0`.

  Based on `TensorFlow's implementation
  <https://www.tensorflow.org/probability/api_docs/python/tfp/math/log1mexp>`_.

  References:
    .. [1] Martin Mächler. `Accurately Computing log(1 − exp(−|a|)) Assessed by the Rmpfr package.
       <https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf>`_.
  """
  x = numpy_util.ensure_arraylike("log1mexp", x)
  c = jnp.log(2.0)
  return jnp.where(
      x < c,
      jnp.log(-jnp.expm1(-x)),
      jnp.log1p(-jnp.exp(-x)),
  )

log1mexp.defjvps(lambda g, ans, x: g / jnp.expm1(x))
```

The function added by this PR:

```python
@jax.jit
@custom_jvp
def prelu(x: ArrayLike, a: ArrayLike) -> Array:
  """Applies the PReLU (Parametric ReLU) activation function element-wise.

  Similar to ReLU, but includes a slope parameter ``a`` that controls the
  slope for negative inputs, allowing the network to adaptively retain some
  negative information and prevent dead neurons.

  Args:
    x (Array): Input tensor.
    a (Array): Slope parameter.

  Returns:
    Output tensor with the PReLU activation applied.
    If ``x >= 0`` it returns ``x``; otherwise it returns ``x * a``.

  Examples:
    x = jnp.array([0.1, 0.2, 3, -2])
    a = 0.1  (placeholder scalar value; ``a`` is trained through backpropagation)
    output: [ 0.1  0.2  3.  -0.2]
  """
  x = jnp.asarray(x)
```
Collaborator
Use

Author
I had actually used it at the beginning of the function, in the parentheses.

Author
Here I had used it.
```python
  a = jnp.asarray(a)
  # This form is more idiomatic because it avoids boolean masking.
  return jnp.maximum(x, 0) + a * jnp.minimum(x, 0)

# prelu.defjvps(
#     lambda g, ans, x, a: lax.select(x >= 0, g, a * g),                   # d(prelu)/dx: 1 where x >= 0, a where x < 0
#     lambda g, ans, x, a: lax.select(x < 0, x * g, lax.full_like(g, 0)),  # d(prelu)/da: x where x < 0, 0 where x >= 0
# )
```
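For reference, a minimal, self-contained sketch of what the commented-out JVP registration could look like once enabled. This is an assumption, not the PR's final code: it places `custom_jvp` outermost (as `jax.nn.relu` does) so that `defjvps` can be called on the returned object, and it uses `jnp.where` in place of `lax.select` so a scalar slope broadcasts cleanly:

```python
import jax
import jax.numpy as jnp
from jax import custom_jvp


# custom_jvp is outermost so that .defjvps can be attached to the returned object.
@custom_jvp
@jax.jit
def prelu(x, a):
  x = jnp.asarray(x)
  a = jnp.asarray(a)
  # x where x >= 0, a * x where x < 0, without boolean masking.
  return jnp.maximum(x, 0) + a * jnp.minimum(x, 0)


# One tangent rule per positional argument; each rule gets (tangent, primal_out, *primals).
prelu.defjvps(
    lambda g, ans, x, a: jnp.where(x >= 0, g, a * g),    # d(prelu)/dx is 1 for x >= 0, a otherwise
    lambda g, ans, x, a: jnp.where(x >= 0, 0.0, x * g),  # d(prelu)/da is 0 for x >= 0, x otherwise
)
```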
We should add an example or two here.
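Along those lines, a minimal usage sketch, assuming the `prelu` defined in this diff is in scope (the excerpt does not show where it is exported, so no import path is claimed):

```python
import jax
import jax.numpy as jnp

# `prelu` as defined in the diff above; its public import path is not shown in this excerpt.
x = jnp.array([0.1, 0.2, 3.0, -2.0])
a = 0.1  # slope for negative inputs; in a network this would be a learned parameter

print(prelu(x, a))
# [ 0.1  0.2  3.  -0.2]  -- positive entries pass through, the -2 entry is scaled by a

# Gradient with respect to the slope: d(prelu)/da is x where x < 0 and 0 elsewhere,
# so summing over this batch gives -2.0.
print(jax.grad(lambda a: jnp.sum(prelu(x, a)))(a))
# -2.0
```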