Implementing PReLU #29147
Conversation
Implementing the Parametric ReLU from the well-known paper "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification", published at ICCV 2015.
Could you please provide feedback on my implementation to confirm that it has been implemented correctly?

Hi - thanks for the contribution! It looks like this would be a good contribution to

Yes of course, sir, I would.

So would you be willing to give me some time to work on it? Currently, I am a bit busy.

Made the function differentiable, defined "a" as an explicit argument, deleted unnecessary args, will add tests later

Would like to hear your feedback!

I will also add tests later.

Adding grad and value tests

OK, the PR combination has been done.
jax/_src/nn/functions.py
Outdated
Args:
    x (jnp.ndarray): Input tensor.
    a (jnp.ndarray): Learnable parameter.
Please use two-space indentation. Also, "learnable parameter" is not particularly descriptive – maybe say "slope parameter" or something similar?
Additionally: we should not list the type in parentheses after the argument name (see other docstrings in this file for examples of proper formatting).
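For concreteness, a hypothetical sketch of the style being suggested (two-space indents, no types after the argument names, and a more descriptive name for the slope); the wording is illustrative, not the PR's final text:

def prelu(x, a):
  """Parametric ReLU activation function.

  Args:
    x: input array.
    a: slope parameter applied where the input is negative.

  Returns:
    An array of the same shape as x with PReLU applied elementwise.
  """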
jax/_src/nn/functions.py
Outdated
Returns:
    jnp.ndarray: Output tensor with the PReLU activation applied.
Remove the jnp.ndarray:
Returns:
    jnp.ndarray: Output tensor with the PReLU activation applied.
"""
We should add an example or two here.
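For example, something along these lines (a sketch only; prelu is written out inline here since it is the function this PR proposes to add, and the numbers are purely illustrative):

import jax.numpy as jnp

def prelu(x, a):
  # Formulation under review: piecewise select between x and a * x.
  return jnp.where(x >= 0, x, a * x)

print(prelu(jnp.array([-2.0, -1.0, 0.0, 1.0]), 0.1))  # ~ [-0.2, -0.1, 0., 1.]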
    jnp.ndarray: Output tensor with the PReLU activation applied.
"""

x = jnp.asarray(x)
Use check_arraylike before passing these to asarray
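A minimal sketch of the suggested pattern, assuming the check_arraylike helper from jax._src.numpy.util that other functions in this file use for input validation:

from jax._src.numpy.util import check_arraylike
import jax.numpy as jnp

def prelu(x, a):
  check_arraylike("prelu", x, a)  # raise a clear TypeError for non-array-like inputs
  x = jnp.asarray(x)
  a = jnp.asarray(a)
  return jnp.where(x >= 0, x, a * x)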
I had actually used it at the beginning of the function, in the parentheses.
def prelu(x: ArrayLike, a: ArrayLike) -> Array:
Here I had used it.
jax/_src/nn/functions.py
Outdated
x = jnp.asarray(x)
a = jnp.asarray(a)
return jnp.where(x >= 0, x, a * x)
I'm not sure that this will be sufficient under autodiff: for example, relu has a custom JVP rule and some notes about its gradient behavior in the docstring. We should probably do something similar here, but I'm not entirely sure what the best form to use is. a * relu(x) may be sufficient, but in other places it seems to be implemented as max(0, x) + a * min(0, x).
It would probably require constructing a few examples with gradients to see which form works in practice.
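A rough sketch of such an experiment (not part of the PR), comparing the two candidate forms mentioned above; x == 0 is where they are most likely to differ:

import jax
import jax.numpy as jnp

def prelu_where(x, a):
  # Formulation currently in the PR.
  return jnp.where(x >= 0, x, a * x)

def prelu_minmax(x, a):
  # Alternative form mentioned above.
  return jnp.maximum(0.0, x) + a * jnp.minimum(0.0, x)

a = 0.25
for x in (-2.0, 0.0, 3.0):
  print(x, jax.grad(prelu_where)(x, a), jax.grad(prelu_minmax)(x, a))

Away from zero both forms should agree; at x == 0 the where form routes the whole gradient to the x branch, while max/min tends to split the gradient between tied arguments, so printing both makes the difference concrete.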
The a * relu(x) form may be incorrect, because I checked the official paper and PyTorch implements it as in the paper. So should I do the same?
By the way, I checked the derivative of PReLU: it is 1 for positive inputs, and for negative inputs it is a, the learned slope.
So should I implement it according to its gradient?
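If a custom rule turns out to be needed, a hypothetical sketch (not from the PR) of pinning that derivative down with jax.custom_jvp, analogous to how relu handles x == 0:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def prelu(x, a):
  return jnp.where(x >= 0, x, a * x)

@prelu.defjvp
def _prelu_jvp(primals, tangents):
  x, a = primals
  x_dot, a_dot = tangents
  y = jnp.where(x >= 0, x, a * x)
  # df/dx is 1 for x >= 0 and a for x < 0; df/da is 0 for x >= 0 and x for x < 0.
  dfdx = jnp.where(x >= 0, 1.0, a)
  dfda = jnp.where(x >= 0, 0.0, x)
  return y, dfdx * x_dot + dfda * a_dot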
tests/nn_test.py
Outdated
def testPreluGrad(x, a):
  return jnp.sum(nn.prelu(x, a))
check_grads(testprelugrad, (x, a), order=1)
This test looks strange – copy-paste problem?
Yeah, after looking at the test file I realized that tests for different purposes, like grad and value, are kept separate, so I did the same.
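A hypothetical sketch of what the gradient test could look like once the name mismatch is fixed (the PR's formulation is inlined here as prelu; in the real test it would call nn.prelu):

import jax.numpy as jnp
from jax.test_util import check_grads

def prelu(x, a):
  return jnp.where(x >= 0, x, a * x)

def test_prelu_grad():
  x = jnp.array([-2.0, -0.5, 1.5])
  a = jnp.array(0.25)
  # Reduce to a scalar so check_grads exercises gradients w.r.t. both x and a.
  check_grads(lambda x, a: jnp.sum(prelu(x, a)), (x, a), order=1)

test_prelu_grad()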
For the new diff type see the previous commit
Hello sir, have you been busy recently?
@MythicArrow It seems that
Yeah, it looks similar, but there is a significant difference between them: leaky ReLU uses "a" as a fixed constant, whereas in PReLU "a" is a learnable slope parameter that is trained through backpropagation.
You can pass a JAX array as the second argument to
Yeah indeed @DanisNone – it looks like this is the same as
Given that, I think this PR can probably be closed.
But PReLU's "a" is learnable, not a constant.
I can make it trainable through backpropagation.

That is true of
Oh, OK.

No problem, sir.
OK, then I will close this PR.
I checked leaky ReLU's definition and see that it has a fixed value of "a", unlike the learnable one in PReLU. So would you merge the PR if I edited the implementation so that "a" becomes learnable and trainable through backpropagation, instead of fixed as in leaky ReLU?
You mention jax/jax/example_libraries/stax.py (line 158, at 9678a76).

If so, then we likely wouldn't accept such a contribution, because we are not adding new features to stax. If you're thinking of
I checked the original leaky ReLU paper, and its "a" is a fixed constant like 0.01 and isn't learned. But in this case JAX's leaky_relu function behaves like PReLU, so I think there is no need to implement PReLU.
Sir, I have researched PReLU and found that it indeed has a learnable parameter which is included in the model's weights. So if I define "a" as a parameter, it will be learnable automatically, unlike in leaky ReLU. Does JAX allow defining "a" as a model parameter?
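As the earlier comments note, JAX differentiates with respect to whatever is placed in the parameter pytree, so "a" can be trained like any other weight. A hypothetical sketch (names and shapes are made up for illustration):

import jax
import jax.numpy as jnp

def prelu(x, a):
  return jnp.where(x >= 0, x, a * x)

def loss(params, x, y):
  pred = prelu(x @ params["w"], params["a"])
  return jnp.mean((pred - y) ** 2)

params = {"w": jnp.ones((3, 1)), "a": jnp.array(0.25)}
x = jnp.array([[0.1, -0.2, 0.3], [-0.5, 0.4, -0.1]])
y = jnp.array([[0.0], [1.0]])

grads = jax.grad(loss)(params, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)  # one SGD step updates both w and a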
Refactor PReLU implementation to use jnp.maximum and jnp.minimum for better idiomatic expression.
Implementing the Parametric ReLU from the well-known paper "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification", published at ICCV 2015. This function introduces a learnable parameter "a", which allows the activation to adapt during training, potentially improving model accuracy and convergence compared to the standard ReLU or leaky ReLU.
Denoted as: f(x) = x if x >= 0, a*x if x < 0
ArXiv link: https://arxiv.org/abs/1502.01852#
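Based on the refactor commit above and the formula in this description, a plausible sketch of the refactored implementation (a reconstruction, not the PR's exact diff):

import jax.numpy as jnp

def prelu(x, a):
  # The positive part passes through; the negative part is scaled by the slope a.
  return jnp.maximum(0.0, x) + a * jnp.minimum(0.0, x)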