44using ChainRules
55using ChainRulesCore
66using Symbolics: Symbolics, @variables , @rule , unwrap, isdiv
7+ import SymbolicUtils
78using SymbolicUtils. Code: toexpr
89using MacroTools
910using MacroTools: prewalk, postwalk
@@ -13,19 +14,58 @@ Pick a strategy for raising the derivative of a function.
1314If the derivative is like 1 over something, raise with the division rule;
1415otherwise, raise with the multiplication rule.
1516"""
16- function get_term_raiser (func)
17- @variables z
18- r1 = @rule - 1 * (1 / ~ x) => (- 1 ) / ~ x
19- der = frule ((NoTangent (), true ), func, z)[2 ]
20- term = unwrap (der)
21- maybe_rewrite = r1 (term)
22- if maybe_rewrite != = nothing
23- term = maybe_rewrite
17+ function get_term_raiser end
18+
19+ @static if pkgversion (Symbolics) < v " 7"
20+ function get_term_raiser (func)
21+ @variables z
22+ r1 = @rule - 1 * (1 / ~ x) => (- 1 ) / ~ x
23+ der = frule ((NoTangent (), true ), func, z)[2 ]
24+ term = unwrap (der)
25+ maybe_rewrite = r1 (term)
26+ if maybe_rewrite != = nothing
27+ term = maybe_rewrite
28+ end
29+ if isdiv (term) && (term. num == 1 || term. num == - 1 )
30+ term. den * term. num, raiseinv
31+ else
32+ term, raise
33+ end
34+ end
35+ else
36+ const COMMON_Z = only (@variables z)
37+ const FALLBACK_RULE = (@rule - 1 * (1 / ~ x) => (- 1 ) / ~ x)
38+
39+ function is_plusminus_one (@nospecialize (x))
40+ if x isa Int
41+ return x == 1 || x == - 1
42+ elseif x isa Int32
43+ return x == 1 || x == - 1
44+ elseif x isa Float64
45+ return x == 1 || x == - 1
46+ elseif x isa Float32
47+ return x == 1 || x == - 1
48+ elseif x isa Number
49+ return (x == 1 ):: Bool || (x == - 1 ):: Bool
50+ else
51+ return false
52+ end
2453 end
25- if isdiv (term) && (term. num == 1 || term. num == - 1 )
26- term. den * term. num, raiseinv
27- else
28- term, raise
54+
55+ function get_term_raiser (func)
56+ der = frule ((NoTangent (), true ), func, COMMON_Z)[2 ]
57+ term = unwrap (der)
58+ maybe_rewrite = FALLBACK_RULE (term)
59+ if maybe_rewrite != = nothing
60+ term = maybe_rewrite
61+ end
62+ if isdiv (term)
63+ num, den = SymbolicUtils. arguments (term)
64+ if SymbolicUtils. isconst (num) && is_plusminus_one (SymbolicUtils. unwrap_const (num))
65+ return num * den, raiseinv
66+ end
67+ end
68+ return term, raise
2969 end
3070end
3171
0 commit comments