Conversation

@MythicArrow

This implements the Parametric ReLU (PReLU) from the well-known paper "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" (ICCV 2015). The function introduces a learnable parameter "a" that lets the activation adapt during training, potentially improving model accuracy and convergence compared to the standard ReLU or Leaky ReLU.
Defined as: f(x) = x if x >= 0, and f(x) = a*x if x < 0
ArXiv link: https://arxiv.org/abs/1502.01852#
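
For reference, a minimal sketch of that formula in JAX (illustrative only, not the PR's final code):

import jax.numpy as jnp

def prelu(x, a):
  # f(x) = x for x >= 0, a * x for x < 0
  return jnp.where(x >= 0, x, a * x)

print(prelu(jnp.array([-2.0, -0.5, 1.0, 3.0]), 0.25))  # [-0.5 -0.125 1. 3.]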

@MythicArrow
Author

Could you please provide feedback on my implementation to confirm that it has been implemented correctly?

@jakevdp jakevdp self-assigned this Jun 2, 2025
@jakevdp
Collaborator

jakevdp commented Jun 2, 2025

Hi - thanks for the contribution! It looks like this would be a good addition to jax.nn, but there are a number of changes we'd have to make: mainly, the a value should be an explicit parameter of the function; otherwise we wouldn't be able to differentiate with respect to it. With this in mind, the function should not take init, rng, or num_parameters as arguments. Also, the implementation should probably be modeled after that of the existing relu. We would also need to add tests for the new function in tests/nn_test.py. Is this something you'd like to work on?
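
For illustration, a hedged sketch of what the suggested signature enables (the code below is an assumption, not the final implementation):

import jax
import jax.numpy as jnp

def prelu(x, a):
  return jnp.where(x >= 0, x, a * x)

x = jnp.array([-2.0, -0.5, 1.0, 3.0])
a = jnp.float32(0.25)

# With a as an explicit argument, JAX can differentiate with respect to it,
# so no init/rng/num_parameters machinery is needed inside the function:
print(jax.grad(lambda a: prelu(x, a).sum())(a))  # -2.5, the sum of the negative inputs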

@MythicArrow
Author

Yes of course sir, I would.

@MythicArrow
Author

So would you be willing to give me some time to work on it? I am a bit busy at the moment.

Made the function differentiable, defined "a" as an explicit argument, deleted unnecessary args; will add tests later
@MythicArrow
Author

Would like to hear your feedback!

@MythicArrow
Author

I will also add tests later.

Adding grad and value tests
@MythicArrow
Author

OK, the changes have been combined in the PR.

Args:
x (jnp.ndarray): Input tensor.
a (jnp.ndarray): Learnable parameter.
Collaborator

@jakevdp jakevdp Jun 3, 2025

Please use two-space indentations. Also "learnable parameter" is not particularly descriptive – maybe say "slope parameter" or something similar?

Additionally: we should not list the type in parentheses after the argument name (see other docstrings in this file for examples of proper formatting).

Returns:
jnp.ndarray: Output tensor with the PReLU activation applied.
Collaborator

Remove the jnp.ndarray:

Returns:
jnp.ndarray: Output tensor with the PReLU activation applied.
"""
Collaborator

We should add an example or two here.
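
Putting the docstring suggestions together, a hypothetical sketch of what the Args/Returns/Examples sections could look like, assuming the function lands as jax.nn.prelu (wording and values are assumptions, not the PR's code):

Args:
  x: input array.
  a: array of slope parameters for the negative elements of x; must be
    broadcast-compatible with x.

Returns:
  An array containing the element-wise PReLU activation of x.

Examples:
  >>> jax.nn.prelu(jnp.array([-2.0, -0.5, 1.0, 3.0]), 0.25)
  Array([-0.5  , -0.125,  1.   ,  3.   ], dtype=float32)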

jnp.ndarray: Output tensor with the PReLU activation applied.
"""

x = jnp.asarray(x)
Collaborator

Use check_arraylike before passing these to asarray

Author

I had actually used it at the beginning of the function, in the parentheses.

Author

@MythicArrow MythicArrow Jun 3, 2025

def prelu(x: ArrayLike, a: ArrayLike) -> Array:

Here I had used it.
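
For clarity: the ArrayLike annotation in the signature is only a static type hint, while check_arraylike is a runtime input check used by other jax.nn functions. A hedged sketch of the requested pattern (check_arraylike is an internal utility, so the import path shown is an assumption):

import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike
from jax._src.numpy.util import check_arraylike  # internal helper; exact path is an assumption

def prelu(x: ArrayLike, a: ArrayLike) -> Array:
  check_arraylike("prelu", x, a)  # raises TypeError for non-array inputs such as Python lists
  x = jnp.asarray(x)
  a = jnp.asarray(a)
  return jnp.where(x >= 0, x, a * x)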


x = jnp.asarray(x)
a = jnp.asarray(a)
return jnp.where(x >= 0, x, a * x)
Collaborator

I'm not sure that this will be sufficient under autodiff: for example, relu has a custom JVP rule and some notes about its gradient behavior in the docstring. We should probably do similar here, but I'm not entirely sure what the best form to use is. a * relu(x) may be sufficient, but in other places it seems to be implemented as max(0, x) + a * min(0, x).

It would probably require constructing a few examples with gradients to see which form works in practice.
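
For example, a quick hedged experiment along those lines (illustrative only), comparing the gradients of the two candidate forms at the kink x = 0:

import jax
import jax.numpy as jnp

def prelu_where(x, a):
  return jnp.where(x >= 0, x, a * x)

def prelu_maxmin(x, a):
  return jnp.maximum(0.0, x) + a * jnp.minimum(0.0, x)

for f in (prelu_where, prelu_maxmin):
  # gradients with respect to x and a, evaluated at x = 0, a = 0.25
  print(f.__name__, jax.grad(f, argnums=(0, 1))(0.0, 0.25))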

Author

@MythicArrow MythicArrow Jun 3, 2025

The a * relu(x) form may be incorrect: I checked the official paper, and PyTorch's implementation follows it. Should I do the same here?

Author

By the way, I checked the derivative of PReLU: it is 1 for positive inputs and a (the learned slope) for negative inputs.
Should I implement it according to that gradient?
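
If an explicit gradient rule turned out to be needed, a hedged sketch of encoding exactly that derivative with jax.custom_jvp, similar in spirit to relu's custom rule (an illustration, not the PR's code): df/dx = 1 for x >= 0 and a for x < 0, and df/da = 0 for x >= 0 and x for x < 0.

import jax
import jax.numpy as jnp

@jax.custom_jvp
def prelu(x, a):
  return jnp.where(x >= 0, x, a * x)

prelu.defjvps(
  lambda x_dot, ans, x, a: jnp.where(x >= 0, 1.0, a) * x_dot,  # tangent rule for x
  lambda a_dot, ans, x, a: jnp.where(x >= 0, 0.0, x) * a_dot,  # tangent rule for a
)

print(jax.grad(prelu, argnums=(0, 1))(-2.0, 0.25))  # (0.25, -2.0)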

tests/nn_test.py Outdated

def testPreluGrad(x, a):
return jnp.sum(nn.prelu(x, a))
check_grads(testprelugrad, (x, a), order=1)
Collaborator

This test looks strange – copy-paste problem?

Author

Yeah, after seeing the test file I realized that the tests for different purposes like grad and value were separated so I did the same.
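
For reference, a hedged sketch of what a separate gradient test could look like (nn.prelu refers to the function proposed in this PR; names and values are assumptions):

import jax.numpy as jnp
from jax import nn
from jax.test_util import check_grads

def test_prelu_grad():
  x = jnp.array([-2.0, -0.5, 1.0, 3.0])  # keep test points away from the kink at 0
  a = jnp.float32(0.25)
  check_grads(nn.prelu, (x, a), order=1, modes=["fwd", "rev"])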

@MythicArrow
Author

Hello sir, have you been busy recently?

@DanisNone
Contributor

@MythicArrow It seems that prelu is equivalent to leaky_relu?

@MythicArrow
Author

Yeah, it looks similar, but there is a significant difference between them: in Leaky ReLU, "a" is a fixed constant, whereas in PReLU "a" is a learnable slope parameter trained through backpropagation.

@DanisNone
Contributor

You can pass a JAX array as the second argument to leaky_relu, and JAX will have no issues computing gradients through it.

@jakevdp
Collaborator

jakevdp commented Jun 10, 2025

Yeah indeed @DanisNone – it looks like this is the same as leaky_relu. Sorry @MythicArrow, I should have noticed that earlier.

@jakevdp
Collaborator

jakevdp commented Jun 10, 2025

I think given that, this PR can probably be closed.

@MythicArrow
Author

MythicArrow commented Jun 10, 2025

But PReLU's "a" is learnable, not a constant.

@MythicArrow
Author

I can make it trainable through backpropagation.

@jakevdp
Collaborator

jakevdp commented Jun 10, 2025

But PReLU's "a" is learnable, not a constant.

That is true of leaky_relu as well.

@MythicArrow
Author

But PReLU's "a" is learnable, not a constant.

That is true of leaky_relu as well.

Oh ok

@MythicArrow
Author

Yeah indeed @DanisNone – it looks like this is the same as leaky_relu. Sorry @MythicArrow, I should have noticed that earlier.

No problem sir.

@MythicArrow
Author

OK, then I will close this PR.

@MythicArrow
Author

I checked Leaky ReLU's definition again, and I see that it uses a fixed value of "a", unlike the learnable one in PReLU. Would you merge the PR if I edited the implementation so that "a" is learnable and trainable through backpropagation, instead of fixed as in Leaky ReLU?

@jakevdp
Collaborator

jakevdp commented Jun 16, 2025

You mention LeakyReLu – are you talking about the function in the stax example library?

LeakyRelu = elementwise(leaky_relu)

If so, then we likely wouldn't accept such a contribution, because we are not adding new features to stax.

If you're thinking of jax.nn.leaky_relu, then I can't say I understand your request here. The negative_slope is a parameter just like any other, and you can take the gradient with respect to it, which means its value can be optimized given an appropriate loss function. That makes me think it qualifies as a "learnable parameter" – is there something more you're looking for?
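
For example, a hedged sketch of that point with the existing API (illustrative only):

import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 1.0, 3.0])

def loss(negative_slope):
  # any differentiable loss involving the activation works here
  return jnp.sum(jax.nn.leaky_relu(x, negative_slope) ** 2)

print(jax.grad(loss)(jnp.float32(0.25)))  # nonzero gradient w.r.t. the slope, usable by an optimizer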

@MythicArrow
Author

I checked the original Leaky ReLU paper, and its "a" was a fixed constant like 0.01 that wasn't learned. But JAX's leaky_relu behaves like PReLU in this respect, so I think there is no need to implement PReLU.

@MythicArrow
Author

Sir, I have researched PReLU and found that its learnable parameter is indeed included in the model's weights. So if I define it as a parameter, it will be learned automatically, unlike Leaky ReLU's fixed slope. Does JAX allow defining 'a' as a model parameter?
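
On that question, a hedged sketch of how a slope could live in a model's parameter pytree and be updated alongside the weights (everything below, including the per-channel shape, is an assumption for illustration):

import jax
import jax.numpy as jnp

params = {
  "w": jnp.eye(4),            # stand-in for ordinary weights
  "a": jnp.full((4,), 0.25),  # per-channel slope, initialized to 0.25 as in the PReLU paper
}

def model(params, x):
  h = x @ params["w"]
  return jax.nn.leaky_relu(h, params["a"])  # slope taken from the params pytree

def loss(params, x):
  return jnp.sum(model(params, x) ** 2)

x = jnp.array([[-2.0, -0.5, 1.0, 3.0]])
grads = jax.grad(loss)(params, x)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)  # one SGD step updates "a" too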

@MythicArrow MythicArrow reopened this Dec 11, 2025
Refactor PReLU implementation to use jnp.maximum and jnp.minimum for better idiomatic expression.