Removing linear dependencies in orthogonalization #18982
Hello, I've started using JAX for my research and have run into trouble performing a symmetric orthogonalization and getting what I expect for the derivative. When computing the derivative, I get for the evals but for the `sqrtm` matrix I get which seems to neglect the
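The code from the question did not survive, but a symmetric (Löwdin) orthogonalization that drops near-linearly-dependent directions might look roughly like the sketch below. Only the `abs(evals) > cutoff` mask comes from the thread. The function name `symmetric_orthogonalizer`, the `cutoff` default, and the "double where" guard are illustrative assumptions, and `S` is assumed to be a symmetric positive semidefinite overlap matrix.

```python
import jax
import jax.numpy as jnp


def symmetric_orthogonalizer(S, cutoff=1e-6):
    """Hypothetical sketch: S^{-1/2} via eigh, dropping eigenvalues with
    |lambda| <= cutoff (the linearly dependent directions).

    Assumes S is a symmetric positive semidefinite overlap matrix.
    """
    evals, evecs = jnp.linalg.eigh(S)
    keep = jnp.abs(evals) > cutoff
    # "Double where" (an added assumption, not from the thread): substitute a
    # harmless value before 1/sqrt so dropped entries cannot produce inf/NaN
    # in the primal or in reverse-mode gradients.
    safe_evals = jnp.where(keep, evals, 1.0)
    inv_sqrt = jnp.where(keep, 1.0 / jnp.sqrt(safe_evals), 0.0)
    # U diag(inv_sqrt) U^T: broadcasting scales column j of U by inv_sqrt[j].
    return (evecs * inv_sqrt) @ evecs.T


# Full-rank overlap: X^T S X should be the identity.
S_full = jnp.array([[1.0, 0.5], [0.5, 2.0]])
X_full = symmetric_orthogonalizer(S_full)
print(jnp.allclose(X_full.T @ S_full @ X_full, jnp.eye(2), atol=1e-4))

# Rank-deficient overlap (two identical basis vectors): one direction is dropped.
S_dep = jnp.array([[1.0, 1.0], [1.0, 1.0]])
X_dep = symmetric_orthogonalizer(S_dep, cutoff=1e-4)

# Forward-mode derivative along a symmetric perturbation of the overlap.
dS = jnp.array([[0.0, 0.1], [0.1, 0.0]])
X_out, X_dot = jax.jvp(
    lambda S: symmetric_orthogonalizer(S, cutoff=1e-4), (S_dep,), (dS,)
)
```

Because the mask is computed from the primal eigenvalues, the dropped direction contributes nothing to either `X_out` or `X_dot`.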
Thanks for the question! Just to make sure I understand: you're surprised to see values smaller than
Indeed, when differentiating the `jnp.where`, the primal values of the boolean array `abs(evals) > cutoff` are used to filter both the primal and tangent values. It's the same idea as differentiating `lambda x: x ** 2 if x > 0 else x` at primal value `x=1.0` but with tangent value `x_dot=-1.0`: we're linearizing around the primal point, so we want to switch based on the primal value only and then have the tangent value follow along (i.e. go through the `x ** 2` function) rather than take its own path. In this case we're writing a `jnp.where` instead of an `if`, but it's the same logic (like differentiating `lambda x: jnp.where(x > 0, x ** 2, x)`).

So, super concretely, when we differentiate
l…
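The scalar comparison above can be checked directly with `jax.jvp`. This is a minimal sketch; `f_if` and `f_where` are illustrative names, not code from the thread. Python `if` works here because `jax.jvp` is not under `jit`, so the primal value is concrete.

```python
import jax
import jax.numpy as jnp


def f_if(x):
    # Python branch: decided by the (concrete) primal value of x.
    return x ** 2 if x > 0 else x


def f_where(x):
    # jnp.where version: the mask x > 0 is computed from the primal, and the
    # tangent is taken from whichever branch that mask selects.
    return jnp.where(x > 0, x ** 2, x)


# Linearize at primal x = 1.0 with tangent x_dot = -1.0.
out_if, dot_if = jax.jvp(f_if, (1.0,), (-1.0,))
out_where, dot_where = jax.jvp(f_where, (1.0,), (-1.0,))

# Both select the x ** 2 branch because the primal 1.0 > 0, so the tangent is
# 2 * x * x_dot = -2.0; the sign of x_dot never changes which branch is used.
print(float(out_if), float(dot_if))
print(float(out_where), float(dot_where))
```

At a negative primal such as `x=-1.0`, the same functions would instead pass the tangent through the identity branch unchanged.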