I was creating a new package with JAX (and its ecosystem) and was writing tests to check that the gradients are correct. However, this results in an error. Can anyone tell me what is going wrong here?
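A minimal sketch of such a gradient test, assuming the package exposes a scalar loss over a pytree of parameters. `dummy_loss` and `init_params` below are hypothetical stand-ins, and enabling 64-bit mode at import time is an assumption about how the test module is set up. It makes the finite-difference side of `jax.test_util.check_grads` much more accurate.

```python
import jax

# Must run before any arrays are created: switches JAX's default float to float64
# so the finite differences inside check_grads suffer far less rounding error.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.test_util import check_grads


def dummy_loss(params):
    # Hypothetical smooth loss over a dict (pytree) of parameters.
    w, b = params["w"], params["b"]
    return jnp.sum(jnp.tanh(3.0 * w + b) ** 2)


# Hypothetical initial parameters; float64 because x64 mode is enabled above.
init_params = {"w": jnp.linspace(-1.0, 1.0, 5), "b": jnp.array(0.5)}

# Compares forward- and reverse-mode derivatives up to second order against
# central differences with step eps; raises AssertionError on a mismatch.
check_grads(dummy_loss, (init_params,), order=2, eps=1e-6)
```

Because `jax_enable_x64` is a process-wide flag, setting it in one test module affects every test in the same run; setting the `JAX_ENABLE_X64=1` environment variable for the whole test suite is an alternative.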
```python
from jax.test_util import check_grads

check_grads(dummy_loss, (init_params,), order=2, eps=1e-6)
```
If I run your example in 64-bit precision with `eps=1E-6`, I find that the gradients match. In float32 precision, if I use a smaller `eps`, the numerical result diverges. This makes me think that your function has a fast-varying second derivative, which makes the numerical gradient inaccurate; nevertheless, the analytic gradient is probably producing the correct value.
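To illustrate how precision and `eps` interact, here is a small NumPy sketch. It is not JAX-specific and does not use the original loss: it applies the central-difference formula to a hypothetical fast-oscillating function `sin(50 x)` and compares the result with the exact derivative in float32 and float64. With this toy function, float64 at `eps=1e-6` agrees with the exact derivative to many digits, while float32 at the same step is off on the order of 1%, because rounding in `x ± eps` and in `50 * x` is no longer small relative to the step.

```python
import numpy as np

K = 50.0  # hypothetical frequency; larger K means faster-varying derivatives


def f(x):
    # Toy function standing in for a loss with fast-varying derivatives.
    return np.sin(K * x)


def f_prime(x):
    # Exact derivative, playing the role of the autodiff gradient.
    return K * np.cos(K * x)


def central_difference(fn, x, eps):
    # Symmetric finite difference, the kind of estimate a gradient check uses.
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


def relative_error(dtype, eps, x0=0.3):
    # A 1-element array (rather than a scalar) keeps the arithmetic in the
    # requested dtype under both old and new NumPy type-promotion rules.
    x = np.array([x0], dtype=dtype)
    approx = float(central_difference(f, x, eps)[0])
    exact = float(f_prime(np.float64(x0)))
    return abs(approx - exact) / abs(exact)


for dtype in (np.float32, np.float64):
    for eps in (1e-2, 1e-4, 1e-6):
        err = relative_error(dtype, eps)
        print(f"{np.dtype(dtype).name:<8} eps={eps:.0e}  relative error={err:.1e}")
```

In float64 the error keeps shrinking as `eps` goes from `1e-2` to `1e-6`, while in float32 it grows again once rounding error dominates the truncation error, which matches the divergence described above.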