5555
5656ProjectTo (:: T ) where {T <: TaylorScalar } = ProjectTo {T} ()
5757(p:: ProjectTo{T} )(x:: T ) where {T <: TaylorScalar } = x
58- ProjectTo (x:: AbstractArray{T} ) where {T <: TaylorScalar } = ProjectTo {AbstractArray} (; element= ProjectTo (zero (T)), axes= axes (x))
58+ function ProjectTo (x:: AbstractArray{T} ) where {T <: TaylorScalar }
59+ ProjectTo {AbstractArray} (; element = ProjectTo (zero (T)), axes = axes (x))
60+ end
5961(p:: ProjectTo{AbstractArray{T}} )(x:: AbstractArray{T} ) where {T <: TaylorScalar } = x
6062accum_sum (xs:: AbstractArray{T} ; dims = :) where {T <: TaylorScalar } = sum (xs, dims = dims)
6163
62- TaylorNumeric{T<: TaylorScalar } = Union{T, AbstractArray{<: T }}
64+ TaylorNumeric{T <: TaylorScalar } = Union{T, AbstractArray{<: T }}
6365
64- @adjoint broadcasted (:: typeof (+ ), xs:: Union{Numeric, TaylorNumeric} ...) = broadcast (+ , xs... ), ȳ -> (nothing , map (x -> unbroadcast (x, ȳ), xs)... )
66+ @adjoint function broadcasted (:: typeof (+ ), xs:: Union{Numeric, TaylorNumeric} ...)
67+ broadcast (+ , xs... ), ȳ -> (nothing , map (x -> unbroadcast (x, ȳ), xs)... )
68+ end
6569
66- struct TaylorOneElement{T,N,I, A} <: AbstractArray{T,N}
70+ struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N}
6771 val:: T
6872 ind:: I
6973 axes:: A
70- TaylorOneElement (val:: T , ind:: I , axes:: A ) where {T<: TaylorScalar , I<: NTuple{N,Int} , A<: NTuple{N,AbstractUnitRange} } where {N} = new {T,N,I,A} (val, ind, axes)
74+ function TaylorOneElement (val:: T , ind:: I ,
75+ axes:: A ) where {T <: TaylorScalar , I <: NTuple{N, Int} ,
76+ A <: NTuple{N, AbstractUnitRange} } where {N}
77+ new {T, N, I, A} (val, ind, axes)
78+ end
7179end
7280
7381Base. size (A:: TaylorOneElement ) = map (length, A. axes)
7482Base. axes (A:: TaylorOneElement ) = A. axes
75- Base. getindex (A:: TaylorOneElement{T,N} , i:: Vararg{Int,N} ) where {T,N} = ifelse (i== A. ind, A. val, zero (T))
83+ function Base. getindex (A:: TaylorOneElement{T, N} , i:: Vararg{Int, N} ) where {T, N}
84+ ifelse (i == A. ind, A. val, zero (T))
85+ end
7686
77- ∇getindex (x:: AbstractArray{T, N} , inds) where {T <: TaylorScalar , N} = dy -> begin
78- dx = TaylorOneElement (dy, inds, axes (x))
79- return (_project (x, dx), map (_-> nothing , inds)... )
87+ function ∇getindex (x:: AbstractArray{T, N} , inds) where {T <: TaylorScalar , N}
88+ dy -> begin
89+ dx = TaylorOneElement (dy, inds, axes (x))
90+ return (_project (x, dx), map (_ -> nothing , inds)... )
91+ end
8092end
8193
8294@generated function mul_adjoint (Ω:: TaylorScalar{T, N} , x:: TaylorScalar{T, N} ) where {T, N}
@@ -93,12 +105,14 @@ rrule(::typeof(*), x::TaylorScalar) = rrule(identity, x)
93105function rrule (:: typeof (* ), x:: TaylorScalar , y:: TaylorScalar )
94106 function times_pullback2 (Ω̇)
95107 ΔΩ = unthunk (Ω̇)
96- return (NoTangent (), ProjectTo (x)(mul_adjoint (ΔΩ, y)), ProjectTo (y)(mul_adjoint (ΔΩ, x)))
108+ return (NoTangent (), ProjectTo (x)(mul_adjoint (ΔΩ, y)),
109+ ProjectTo (y)(mul_adjoint (ΔΩ, x)))
97110 end
98111 return x * y, times_pullback2
99112end
100113
101- function rrule (:: typeof (* ), x:: TaylorScalar , y:: TaylorScalar , z:: TaylorScalar , more:: TaylorScalar... )
114+ function rrule (:: typeof (* ), x:: TaylorScalar , y:: TaylorScalar , z:: TaylorScalar ,
115+ more:: TaylorScalar... )
102116 Ω2, back2 = rrule (* , x, y)
103117 Ω3, back3 = rrule (* , Ω2, z)
104118 Ω4, back4 = rrule (* , Ω3, more... )
0 commit comments