Skip to content

Commit 00c4b77

Browse files
AayushSabharwalChrisRackauckas
authored andcommitted
feat: support Symbolics@7
1 parent 53b5ba1 commit 00c4b77

File tree

2 files changed

+53
-13
lines changed

2 files changed

+53
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ MacroTools = "0.5"
2727
NNlib = "0.9"
2828
SpecialFunctions = "2"
2929
SymbolicUtils = "3, 4"
30-
Symbolics = "6"
30+
Symbolics = "6, 7"
3131
Zygote = "0.6, 0.7"
3232
julia = "1.10"

src/utils.jl

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using ChainRules
55
using ChainRulesCore
66
using Symbolics: Symbolics, @variables, @rule, unwrap, isdiv
7+
import SymbolicUtils
78
using SymbolicUtils.Code: toexpr
89
using MacroTools
910
using MacroTools: prewalk, postwalk
@@ -13,19 +14,58 @@ Pick a strategy for raising the derivative of a function.
1314
If the derivative is like 1 over something, raise with the division rule;
1415
otherwise, raise with the multiplication rule.
1516
"""
16-
function get_term_raiser(func)
17-
@variables z
18-
r1 = @rule -1 * (1 / ~x) => (-1) / ~x
19-
der = frule((NoTangent(), true), func, z)[2]
20-
term = unwrap(der)
21-
maybe_rewrite = r1(term)
22-
if maybe_rewrite !== nothing
23-
term = maybe_rewrite
17+
function get_term_raiser end
18+
19+
@static if pkgversion(Symbolics) < v"7"
20+
function get_term_raiser(func)
21+
@variables z
22+
r1 = @rule -1 * (1 / ~x) => (-1) / ~x
23+
der = frule((NoTangent(), true), func, z)[2]
24+
term = unwrap(der)
25+
maybe_rewrite = r1(term)
26+
if maybe_rewrite !== nothing
27+
term = maybe_rewrite
28+
end
29+
if isdiv(term) && (term.num == 1 || term.num == -1)
30+
term.den * term.num, raiseinv
31+
else
32+
term, raise
33+
end
34+
end
35+
else
36+
const COMMON_Z = only(@variables z)
37+
const FALLBACK_RULE = (@rule -1 * (1 / ~x) => (-1) / ~x)
38+
39+
function is_plusminus_one(@nospecialize(x))
40+
if x isa Int
41+
return x == 1 || x == -1
42+
elseif x isa Int32
43+
return x == 1 || x == -1
44+
elseif x isa Float64
45+
return x == 1 || x == -1
46+
elseif x isa Float32
47+
return x == 1 || x == -1
48+
elseif x isa Number
49+
return (x == 1)::Bool || (x == -1)::Bool
50+
else
51+
return false
52+
end
2453
end
25-
if isdiv(term) && (term.num == 1 || term.num == -1)
26-
term.den * term.num, raiseinv
27-
else
28-
term, raise
54+
55+
function get_term_raiser(func)
56+
der = frule((NoTangent(), true), func, COMMON_Z)[2]
57+
term = unwrap(der)
58+
maybe_rewrite = FALLBACK_RULE(term)
59+
if maybe_rewrite !== nothing
60+
term = maybe_rewrite
61+
end
62+
if isdiv(term)
63+
num, den = SymbolicUtils.arguments(term)
64+
if SymbolicUtils.isconst(num) && is_plusminus_one(SymbolicUtils.unwrap_const(num))
65+
return num * den, raiseinv
66+
end
67+
end
68+
return term, raise
2969
end
3070
end
3171

0 commit comments

Comments
 (0)