Skip to content

Commit 6d62163

Browse files
authored
Make external eqs tests work again (#52)
1 parent adb7e39 commit 6d62163

File tree

15 files changed

+210
-173
lines changed

15 files changed

+210
-173
lines changed

Manifest.toml

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ Cthulhu = {rev = "master", url = "https://github.com/JuliaDebug/Cthulhu.jl.git"}
5050
DifferentiationInterface = {rev = "main", subdir = "DifferentiationInterface", url = "https://github.com/Keno/DifferentiationInterface.jl"}
5151
Diffractor = {rev = "main", url = "https://github.com/JuliaDiff/Diffractor.jl.git"}
5252
SimpleNonlinearSolve = {rev = "master", subdir = "lib/SimpleNonlinearSolve", url = "https://github.com/SciML/NonlinearSolve.jl.git"}
53-
StateSelection = {rev = "main", url = "https://github.com/JuliaComputing/StateSelection.jl"}
53+
StateSelection = {rev = "main", url = "https://github.com/JuliaComputing/StateSelection.jl.git"}
5454

5555
[compat]
5656
Accessors = "0.1.36"

src/analysis/cache.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@ struct DAEIPOResult
4848
varkinds::Vector{Union{Intrinsics.VarKind, Nothing}}
4949
eqkinds::Vector{Union{Intrinsics.EqKind, Nothing}}
5050
# TODO: Chain these rather than copying them
51-
warnings::Vector{UnsupportedIRException}
51+
warnings::Vector{BadDAECompilerInputException}
5252
end
5353

5454
struct UncompilableIPOResult
55-
warnings::Vector{UnsupportedIRException}
56-
error::UnsupportedIRException
55+
warnings::Vector{BadDAECompilerInputException}
56+
error::BadDAECompilerInputException
5757
end
5858

5959
function add_equation_row!(graph, solvable_graph, ieq::Int, inc::Incidence)
@@ -92,7 +92,9 @@ function make_structure_from_ipo(ipo::DAEIPOResult)
9292
graph = BipartiteGraph(neqs, nvars)
9393
solvable_graph = BipartiteGraph(neqs, nvars)
9494

95-
for (ieq, inc) in enumerate(ipo.total_incidence)
95+
for ieq in 1:length(ipo.total_incidence)
96+
isassigned(ipo.total_incidence, ieq) || continue
97+
inc = ipo.total_incidence[ieq]
9698
add_equation_row!(graph, solvable_graph, ieq, inc)
9799
end
98100

src/analysis/flattening.jl

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# TODO: We should unify this function with _make_argument_lattice_elem to ensure consistency
21
function _flatten_parameter!(𝕃, compact, argtypes, ntharg, line)
32
list = Any[]
43
for (argn, argt) in enumerate(argtypes)
@@ -8,6 +7,8 @@ function _flatten_parameter!(𝕃, compact, argtypes, ntharg, line)
87
continue
98
elseif Base.isprimitivetype(argt) || isa(argt, Incidence)
109
push!(list, ntharg(argn))
10+
elseif argt === equation || isa(argt, Eq)
11+
continue
1112
elseif isa(argt, Type) && argt <: Intrinsics.AbstractScope
1213
continue
1314
elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct))
@@ -36,64 +37,79 @@ function flatten_parameter!(𝕃, compact, argtypes, ntharg, line)
3637
end
3738

3839
# Needs to match flatten_arguments!
39-
function process_template_arg!(𝕃, coeffs, eq_mapping, applied_scopes, argt, template_argt, offset=0)
40+
function process_template_arg!(𝕃, coeffs, eq_mapping, applied_scopes, argt, template_argt, offset=0, eqoffset=0)::Pair{Int, Int}
4041
if isa(template_argt, Const)
4142
@assert isa(argt, Const) && argt.val === template_argt.val
42-
return offset
43+
return Pair{Int, Int}(offset, eqoffset)
4344
elseif Base.issingletontype(template_argt)
4445
@assert isa(template_argt, Type) && argt.instance === template_argt.instance
45-
return offset
46+
return Pair{Int, Int}(offset, eqoffset)
4647
elseif Base.isprimitivetype(template_argt)
4748
coeffs[offset+1] = argt
48-
return offset + 1
49+
return Pair{Int, Int}(offset + 1, eqoffset)
50+
elseif template_argt === equation
51+
eq_mapping[eqoffset+1] = argt.id
52+
return Pair{Int, Int}(offset, eqoffset + 1)
4953
elseif isabstracttype(template_argt) || ismutabletype(template_argt) || (!isa(template_argt, DataType) && !isa(template_argt, PartialStruct))
50-
return offset
54+
return Pair{Int, Int}(offset, eqoffset)
5155
else
5256
if !isa(template_argt, PartialStruct) && Base.datatype_fieldcount(template_argt) === nothing
53-
return offset
57+
return Pair{Int, Int}(offset, eqoffset)
5458
end
5559
template_fields = isa(template_argt, PartialStruct) ? template_argt.fields : collect(fieldtypes(template_argt))
5660
return process_template!(𝕃, coeffs, eq_mapping, applied_scopes, Any[Compiler.getfield_tfunc(𝕃, argt, Const(i)) for i = 1:length(template_fields)], template_fields, offset)
5761
end
5862
end
5963

60-
function process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, template_argtypes, offset=0)
64+
function process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, template_argtypes, offset=0, eqoffset=0)
6165
@assert length(argtypes) == length(template_argtypes)
6266
for (i, template_arg) in enumerate(template_argtypes)
63-
offset = process_template_arg!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes[i], template_arg, offset)
67+
(offset, eqoffset) = process_template_arg!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes[i], template_arg, offset)
6468
end
65-
return offset
69+
return Pair{Int, Int}(offset, eqoffset)
70+
end
71+
72+
struct TransformedArg
73+
ssa::Any
74+
offset::Int
75+
eqoffset::Int
76+
TransformedArg(@nospecialize(arg), new_offset::Int, new_eqoffset::Int) = new(arg, new_offset, new_eqoffset)
6677
end
6778

68-
function flatten_argument!(compact::Compiler.IncrementalCompact, argt, offset, argtypes::Vector{Any})::Pair{Any, Int}
69-
@assert !isa(argt, Incidence)
79+
function flatten_argument!(compact::Compiler.IncrementalCompact, @nospecialize(argt), offset::Int, eqoffset::Int, argtypes::Vector{Any})::TransformedArg
80+
@assert !isa(argt, Incidence) && !isa(argt, Eq)
7081
if isa(argt, Const)
71-
return Pair{Any, Int}(argt.val, offset)
82+
return TransformedArg(argt.val, offset, eqoffset)
7283
elseif Base.issingletontype(argt)
73-
return Pair{Any, Int}(argt.instance, offset)
84+
return TransformedArg(argt.instance, offset, eqoffset)
7485
elseif Base.isprimitivetype(argt)
7586
push!(argtypes, argt)
76-
return Pair{Any, Int}(Argument(offset+1), offset+1)
87+
return TransformedArg(Argument(offset+1), offset+1, eqoffset)
88+
elseif argt === equation
89+
ssa = insert_node_here!(compact, NewInstruction(Expr(:invoke, nothing, InternalIntrinsics.external_equation), Eq(eqoffset+1), compact[Compiler.OldSSAValue(1)][:line]))
90+
return TransformedArg(ssa, offset, eqoffset+1)
7791
elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct))
7892
ssa = insert_node_here!(compact, NewInstruction(Expr(:call, error, "Cannot IPO model arg type $argt"), Union{}, compact[Compiler.OldSSAValue(1)][:line]))
79-
return Pair{Any, Int}(ssa, offset)
93+
return TransformedArg(ssa, -1, eqoffset)
8094
else
8195
if !isa(argt, PartialStruct) && Base.datatype_fieldcount(argt) === nothing
8296
ssa = insert_node_here!(compact, NewInstruction(Expr(:call, error, "Cannot IPO model arg type $argt"), Union{}, compact[Compiler.OldSSAValue(1)][:line]))
83-
return Pair{Any, Int}(ssa, offset)
97+
return TransformedArg(ssa, -1, eqoffset)
8498
end
85-
(args, _, offset) = flatten_arguments!(compact, isa(argt, PartialStruct) ? argt.fields : fieldtypes(argt), offset, argtypes)
99+
(args, _, offset) = flatten_arguments!(compact, isa(argt, PartialStruct) ? argt.fields : collect(Any, fieldtypes(argt)), offset, eqoffset, argtypes)
100+
offset == -1 && return TransformedArg(ssa, -1, eqoffset)
86101
this = Expr(:new, isa(argt, PartialStruct) ? argt.typ : argt, args...)
87102
ssa = insert_node_here!(compact, NewInstruction(this, argt, compact[Compiler.OldSSAValue(1)][:line]))
88-
return Pair{Any, Int}(ssa, offset)
103+
return TransformedArg(ssa, offset, eqoffset)
89104
end
90105
end
91106

92-
function flatten_arguments!(compact::Compiler.IncrementalCompact, argtypes, offset=0, new_argtypes::Vector{Any} = Any[])
107+
function flatten_arguments!(compact::Compiler.IncrementalCompact, argtypes::Vector{Any}, offset::Int=0, eqoffset::Int=0, new_argtypes::Vector{Any} = Any[])
93108
args = Any[]
94109
for argt in argtypes
95-
(ssa, offset) = flatten_argument!(compact, argt, offset, new_argtypes)
110+
(; ssa, offset, eqoffset) = flatten_argument!(compact, argt, offset, eqoffset, new_argtypes)
111+
offset == -1 && break
96112
push!(args, ssa)
97113
end
98-
return (args, new_argtypes, offset)
114+
return (args, new_argtypes, offset, eqoffset)
99115
end

src/analysis/ipoincidence.jl

Lines changed: 11 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,36 @@ struct CalleeMapping
66
end
77

88
struct CallerMappingState
9-
result::DAEIPOResult
9+
callee_result::DAEIPOResult
1010
caller_var_to_diff::DiffGraph
1111
caller_varclassification::Vector{VarEqClassification}
1212
caller_varkind::Union{Vector{Intrinsics.VarKind}, Nothing}
1313
caller_eqclassification::Vector{VarEqClassification}
14+
caller_eqkinds::Union{Vector{Intrinsics.EqKind}, Nothing}
1415
end
1516

16-
function compute_missing_coeff!(coeffs, (;result, caller_var_to_diff, caller_varclassification, caller_varkind)::CallerMappingState, v)
17+
function compute_missing_coeff!(coeffs, (;callee_result, caller_var_to_diff, caller_varclassification, caller_varkind)::CallerMappingState, v)
1718
# First find the rootvar, and if we already have a coeff for it
1819
# apply the derivatives.
1920
ndiffs = 0
20-
calle_inv = invview(result.var_to_diff)
21+
calle_inv = invview(callee_result.var_to_diff)
2122
while calle_inv[v] !== nothing && !isassigned(coeffs, v)
2223
ndiffs += 1
2324
v = calle_inv[v]
2425
end
2526

2627
if !isassigned(coeffs, v)
27-
@assert v > result.nexternalargvars # Arg vars should have already been mapped
28+
@assert v > callee_result.nexternalargvars # Arg vars should have already been mapped
2829
# Reached the root and it's an internal variable. We need to allocate
2930
# it in the caller now
3031
coeffs[v] = Incidence(add_vertex!(caller_var_to_diff))
31-
push!(caller_varclassification, result.varclassification[v] == External ? Owned : CalleeInternal)
32-
push!(caller_varkind, result.varkinds[v])
32+
push!(caller_varclassification, callee_result.varclassification[v] == External ? Owned : CalleeInternal)
33+
push!(caller_varkind, callee_result.varkinds[v])
3334
end
3435
thisinc = coeffs[v]
3536

3637
for _ = 1:ndiffs
37-
dv = result.var_to_diff[v]
38+
dv = callee_result.var_to_diff[v]
3839
coeffs[dv] = structural_inc_ddt(caller_var_to_diff, caller_varclassification, caller_varkind, thisinc)
3940
v = dv
4041
end
@@ -143,10 +144,9 @@ end
143144
function apply_linear_incidence(𝕃, ret::Eq, caller::CallerMappingState, mapping::CalleeMapping)
144145
eq_mapping = mapping.eqs[ret.id]
145146
if eq_mapping == 0
146-
error("I removed these from StructuralRefiner for conceptual reasons - if we hit these, lets revisit")
147-
#push!(caller_eqclassification, Owned)
148-
#push!(caller_eqkinds, result.eqkinds[ret.id])
149-
mapping.eqs[ret.id] = eq_mapping = length(caller_eqclassification)
147+
push!(caller.caller_eqclassification, Owned)
148+
push!(caller.caller_eqkinds, caller.callee_result.eqkinds[ret.id])
149+
mapping.eqs[ret.id] = eq_mapping = length(caller.caller_eqclassification)
150150
end
151151
return Eq(eq_mapping)
152152
end
@@ -172,47 +172,3 @@ struct MappingInfo <: Compiler.CallInfo
172172
result::DAEIPOResult
173173
mapping::CalleeMapping
174174
end
175-
176-
function _make_argument_lattice_elem(𝕃, which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
177-
if isa(argt, Const)
178-
#@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
179-
return argt
180-
elseif isa(argt, Type) && argt <: Intrinsics.AbstractScope
181-
return PartialScope(add_scope!(which))
182-
elseif isa(argt, Type) && argt == equation
183-
return Eq(add_equation!(which))
184-
elseif is_non_incidence_type(argt)
185-
return argt
186-
elseif Compiler.isprimitivetype(argt)
187-
inc = Incidence(add_variable!(which))
188-
return argt === Float64 ? inc : Incidence(argt, inc.row)
189-
elseif isa(argt, PartialStruct)
190-
return PartialStruct(𝕃, argt.typ, Any[make_argument_lattice_elem(𝕃, which, f, add_variable!, add_equation!, add_scope!) for f in argt.fields])
191-
elseif isabstracttype(argt) || ismutabletype(argt) || !isa(argt, DataType)
192-
return nothing
193-
else
194-
fields = Any[]
195-
any = false
196-
# TODO: This doesn't handle recursion
197-
if Base.datatype_fieldcount(argt) === nothing
198-
return nothing
199-
end
200-
for i = 1:length(fieldtypes(argt))
201-
# TODO: Can we make this lazy?
202-
ft = fieldtype(argt, i)
203-
mft = _make_argument_lattice_elem(𝕃, which, ft, add_variable!, add_equation!, add_scope!)
204-
if mft === nothing
205-
push!(fields, Incidence(ft))
206-
else
207-
any = true
208-
push!(fields, mft)
209-
end
210-
end
211-
return any ? PartialStruct(𝕃, argt, fields) : nothing
212-
end
213-
end
214-
215-
function make_argument_lattice_elem(𝕃, which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
216-
mft = _make_argument_lattice_elem(𝕃, which, argt, add_variable!, add_equation!, add_scope!)
217-
mft === nothing ? Incidence(argt) : mft
218-
end

src/analysis/lattice.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ const linear_time_and_state_dependent = Linearity(nonlinear = false)
8181
"The variable is used nonlinearly, with a possible dependence on time and other states."
8282
const nonlinear = Linearity()
8383

84+
function Base.show(io::IO, lin::Linearity)
85+
if lin === nonlinear
86+
print(io, "nonlinear")
87+
else
88+
invoke(Base.show, Tuple{IO, Any}, io, lin)
89+
end
90+
end
91+
8492
join_linearity(a::Linearity, b::Real) = a
8593
join_linearity(a::Real, b::Linearity) = b
8694
join_linearity(a::Real, b::Real) = a == b ? a : linear

src/analysis/refiner.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ struct StructuralRefiner <: Compiler.AbstractInterpreter
1010
var_to_diff::DiffGraph
1111
varkinds::Vector{Intrinsics.VarKind}
1212
varclassification::Vector{VarEqClassification}
13+
eqkinds::Vector{Intrinsics.EqKind}
14+
eqclassification::Vector{VarEqClassification}
1315
end
1416

1517
struct StructureCache; end
@@ -59,7 +61,7 @@ Compiler.cache_owner(::StructuralRefiner) = StructureCache()
5961
argtypes = Compiler.collect_argtypes(interp, stmt.args, Compiler.StatementState(nothing, false), irsv)[2:end]
6062
mapping = CalleeMapping(Compiler.optimizer_lattice(interp), argtypes, callee_codeinst, callee_result, callee_codeinst.inferred.ir.argtypes)
6163
new_rt = apply_linear_incidence(Compiler.optimizer_lattice(interp), callee_result.extended_rt,
62-
CallerMappingState(callee_result, interp.var_to_diff, interp.varclassification, interp.varkinds, VarEqClassification[]), mapping)
64+
CallerMappingState(callee_result, interp.var_to_diff, interp.varclassification, interp.varkinds, interp.eqclassification, interp.eqkinds), mapping)
6365

6466
# Remember this mapping, both for performance of not having to recompute it
6567
# and because we may have assigned caller variables to internal variables

0 commit comments

Comments
 (0)