Skip to content

Commit 977cd72

Browse files
authored
Normalize nested argument lists early (#40)
As discussed on Slack, this augments structural analysis to flatten argument lists when function have nested argument structures. The primary motivation is to make it easier to track non-linear expressions used as IPO arguments (see newly added tests).
1 parent fce94c1 commit 977cd72

File tree

17 files changed

+398
-228
lines changed

17 files changed

+398
-228
lines changed

src/DAECompiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ module DAECompiler
2323
include("analysis/refiner.jl")
2424
include("analysis/ipoincidence.jl")
2525
include("analysis/structural.jl")
26+
include("analysis/flattening.jl")
2627
include("transform/state_selection.jl")
2728
include("transform/common.jl")
2829
include("transform/runtime.jl")

src/analysis/cache.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ end
2121
var_schedule::Vector{Pair{BitSet, BitSet}}
2222
end
2323

24+
"""
25+
StructuralSSARef
26+
27+
Represents an SSA reference to the IR after structural analysis. Used as keys for callees, etc.
28+
"""
29+
struct StructuralSSARef
30+
id::Int
31+
end
32+
2433
struct DAEIPOResult
2534
ir::IRCode
2635
extended_rt::Any
@@ -34,7 +43,7 @@ struct DAEIPOResult
3443
varclassification::Vector{VarEqClassification}
3544
total_incidence::Vector{Any}
3645
eqclassification::Vector{VarEqClassification}
37-
eq_callee_mapping::Vector{Union{Nothing, Vector{Pair{SSAValue, Int}}}}
46+
eq_callee_mapping::Vector{Union{Nothing, Vector{Pair{StructuralSSARef, Int}}}}
3847
names::OrderedDict{Any, ScopeDictEntry} # TODO: OrderedIdDict
3948
varkinds::Vector{Union{Intrinsics.VarKind, Nothing}}
4049
eqkinds::Vector{Union{Intrinsics.EqKind, Nothing}}

src/analysis/consistency.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function get_inline_backtrace(ir::IRCode, v::SSAValue)
2727
runtime_jl_path = maybe_realpath(joinpath(dirname(pathof(@__MODULE__)), "runtime.jl"))
2828

2929
frames = Base.StackTrace();
30-
for lineinfo in Base.IRShow.buildLineInfoNode(ir.debuginfo, nothing, v.id)
30+
for lineinfo in Compiler.IRShow.buildLineInfoNode(ir.debuginfo, nothing, v.id)
3131
btpath = maybe_realpath(expanduser(string(lineinfo.file)))
3232
if btpath != runtime_jl_path
3333
frame = Base.StackFrame(lineinfo.method, lineinfo.file, lineinfo.line)

src/analysis/flattening.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# TODO: We should unify this function with _make_argument_lattice_elem to ensure consistency
2+
function _flatten_parameter!(𝕃, compact, argtypes, ntharg, line)
3+
list = Any[]
4+
for (argn, argt) in enumerate(argtypes)
5+
if isa(argt, Const)
6+
continue
7+
elseif Base.issingletontype(argt)
8+
continue
9+
elseif Base.isprimitivetype(argt) || isa(argt, Incidence)
10+
push!(list, ntharg(argn))
11+
elseif isa(argt, Type) && argt <: Intrinsics.AbstractScope
12+
continue
13+
elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct))
14+
continue
15+
else
16+
if !isa(argt, PartialStruct) && Base.datatype_fieldcount(argt) === nothing
17+
continue
18+
end
19+
this = ntharg(argn)
20+
nthfield(i) = insert_node_here!(compact,
21+
NewInstruction(Expr(:call, getfield, this, i), Compiler.getfield_tfunc(𝕃, argextype(this, compact), Const(i)), line))
22+
if isa(argt, PartialStruct)
23+
fields = _flatten_parameter!(𝕃, compact, argt.fields, nthfield, line)
24+
else
25+
fields = _flatten_parameter!(𝕃, compact, fieldtypes(argt), nthfield, line)
26+
end
27+
append!(list, fields)
28+
end
29+
end
30+
return list
31+
end
32+
33+
function flatten_parameter!(𝕃, compact, argtypes, ntharg, line)
34+
return insert_node_here!(compact,
35+
NewInstruction(Expr(:call, tuple, _flatten_parameter!(𝕃, compact, argtypes, ntharg, line)...), Tuple, line))
36+
end
37+
38+
# Needs to match flatten_arguments!
39+
function process_template_arg!(𝕃, coeffs, eq_mapping, applied_scopes, argt, template_argt, offset=0)
40+
if isa(template_argt, Const)
41+
@assert isa(argt, Const) && argt.val === template_argt.val
42+
return offset
43+
elseif Base.issingletontype(template_argt)
44+
@assert isa(template_argt, Type) && argt.instance === template_argt.instance
45+
return offset
46+
elseif Base.isprimitivetype(template_argt)
47+
coeffs[offset+1] = argt
48+
return offset + 1
49+
elseif isabstracttype(template_argt) || ismutabletype(template_argt) || (!isa(template_argt, DataType) && !isa(template_argt, PartialStruct))
50+
return offset
51+
else
52+
if !isa(template_argt, PartialStruct) && Base.datatype_fieldcount(template_argt) === nothing
53+
return offset
54+
end
55+
template_fields = isa(template_argt, PartialStruct) ? template_argt.fields : collect(fieldtypes(template_argt))
56+
return process_template!(𝕃, coeffs, eq_mapping, applied_scopes, Any[Compiler.getfield_tfunc(𝕃, argt, Const(i)) for i = 1:length(template_fields)], template_fields, offset)
57+
end
58+
end
59+
60+
function process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, template_argtypes, offset=0)
61+
@assert length(argtypes) == length(template_argtypes)
62+
for (i, template_arg) in enumerate(template_argtypes)
63+
offset = process_template_arg!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes[i], template_arg, offset)
64+
end
65+
return offset
66+
end
67+
68+
function flatten_argument!(compact::Compiler.IncrementalCompact, argt, offset, argtypes::Vector{Any})::Pair{Any, Int}
69+
@assert !isa(argt, Incidence)
70+
if isa(argt, Const)
71+
return Pair{Any, Int}(argt.val, offset)
72+
elseif Base.issingletontype(argt)
73+
return Pair{Any, Int}(argt.instance, offset)
74+
elseif Base.isprimitivetype(argt)
75+
push!(argtypes, argt)
76+
return Pair{Any, Int}(Argument(offset+1), offset+1)
77+
elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct))
78+
ssa = insert_node_here!(compact, NewInstruction(Expr(:call, error, "Cannot IPO model arg type $argt"), Union{}, compact[Compiler.OldSSAValue(1)][:line]))
79+
return Pair{Any, Int}(ssa, offset)
80+
else
81+
if !isa(argt, PartialStruct) && Base.datatype_fieldcount(argt) === nothing
82+
ssa = insert_node_here!(compact, NewInstruction(Expr(:call, error, "Cannot IPO model arg type $argt"), Union{}, compact[Compiler.OldSSAValue(1)][:line]))
83+
return Pair{Any, Int}(ssa, offset)
84+
end
85+
(args, _, offset) = flatten_arguments!(compact, isa(argt, PartialStruct) ? argt.fields : fieldtypes(argt), offset, argtypes)
86+
this = Expr(:new, isa(argt, PartialStruct) ? argt.typ : argt, args...)
87+
ssa = insert_node_here!(compact, NewInstruction(this, argt, compact[Compiler.OldSSAValue(1)][:line]))
88+
return Pair{Any, Int}(ssa, offset)
89+
end
90+
end
91+
92+
function flatten_arguments!(compact::Compiler.IncrementalCompact, argtypes, offset=0, new_argtypes::Vector{Any} = Any[])
93+
args = Any[]
94+
for argt in argtypes
95+
(ssa, offset) = flatten_argument!(compact, argt, offset, new_argtypes)
96+
push!(args, ssa)
97+
end
98+
return (args, new_argtypes, offset)
99+
end

src/analysis/ipoincidence.jl

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -105,54 +105,12 @@ function apply_linear_incidence(𝕃, ret::PartialStruct, caller::CallerMappingS
105105
return PartialStruct(𝕃, ret.typ, Any[apply_linear_incidence(𝕃, f, caller, mapping) for f in ret.fields])
106106
end
107107

108-
function process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, template_argtypes)
109-
for (arg, template) in zip(argtypes, template_argtypes)
110-
if isa(template, Incidence)
111-
if isempty(template)
112-
# @assert iszero(arg)
113-
continue
114-
end
115-
(idxs, vals) = findnz(template.row)
116-
@assert only(vals) == 1.0
117-
@assert !isassigned(coeffs, only(idxs)-1)
118-
coeffs[only(idxs)-1] = arg
119-
elseif isa(template, Eq)
120-
@assert isa(arg, Eq)
121-
eq_mapping[idnum(template)] = idnum(arg)
122-
elseif Compiler.is_const_argtype(template)
123-
#@CC.show (arg, template)
124-
#@assert CC.is_lattice_equal(DAE_LATTICE, arg, template)
125-
elseif isa(template, PartialScope)
126-
id = idnum(template)
127-
(id > length(applied_scopes)) && resize!(applied_scopes, id)
128-
if isa(arg, Const)
129-
@assert isa(arg.val, Union{Scope, Nothing})
130-
applied_scopes[id] = arg.val
131-
elseif isa(arg, PartialScope)
132-
applied_scopes[id] = arg
133-
else
134-
applied_scopes[id] = arg
135-
end
136-
elseif isa(template, PartialStruct)
137-
if isa(arg, PartialStruct)
138-
fields = arg.fields
139-
else
140-
fields = Any[Compiler.getfield_tfunc(𝕃, arg, Const(i)) for i = 1:length(template.fields)]
141-
end
142-
process_template!(𝕃, coeffs, eq_mapping, applied_scopes, fields, template.fields)
143-
else
144-
@show (arg, template, template_argtypes)
145-
error()
146-
end
147-
end
148-
end
149-
150-
function CalleeMapping(𝕃::Compiler.AbstractLattice, argtypes::Vector{Any}, callee_result::DAEIPOResult)
108+
function CalleeMapping(𝕃::Compiler.AbstractLattice, argtypes::Vector{Any}, callee_result::DAEIPOResult, template_argtypes)
151109
applied_scopes = Any[]
152110
coeffs = Vector{Any}(undef, length(callee_result.var_to_diff))
153111
eq_mapping = fill(0, length(callee_result.total_incidence))
154112

155-
process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, callee_result.argtypes)
113+
process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, template_argtypes)
156114

157115
return CalleeMapping(coeffs, eq_mapping, applied_scopes)
158116
end

src/analysis/lattice.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ struct Incidence
109109
typ::Union{Type, Const}
110110
row::IncidenceVector
111111

112-
function Incidence(@nospecialize(type), row)
112+
function Incidence(@nospecialize(type), row::AbstractVector)
113113
if is_non_incidence_type(type)
114114
throw(DomainError(type, "Invalid type for Incidence"))
115115
end
@@ -198,6 +198,12 @@ function Incidence(v::Int)
198198
row[v+1] = 1.0
199199
return Incidence(_ZERO_CONST, row)
200200
end
201+
function Incidence(T::Union{Type, Compiler.Const}, v::Int)
202+
T === Float64 && return Incidence(v)
203+
row = _zero_row()
204+
row[v+1] = nonlinear
205+
return Incidence(T, row)
206+
end
201207

202208
"Identify the id number of an equation or variable"
203209
idnum(a::Incidence) = only(findall(!iszero, a.row)) - 1 # Inverse of above constructor.

src/analysis/refiner.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ Compiler.cache_owner(::StructuralRefiner) = StructureCache()
5757
end
5858

5959
argtypes = Compiler.collect_argtypes(interp, stmt.args, Compiler.StatementState(nothing, false), irsv)[2:end]
60-
mapping = CalleeMapping(Compiler.optimizer_lattice(interp), argtypes, callee_result)
60+
m = Compiler.get_ci_mi(callee_codeinst).def
61+
argtypes = Compiler.va_process_argtypes(Compiler.optimizer_lattice(interp), argtypes, UInt(m.nargs), m.isva)
62+
mapping = CalleeMapping(Compiler.optimizer_lattice(interp), argtypes, callee_result, callee_codeinst.inferred.ir.argtypes)
6163
new_rt = apply_linear_incidence(Compiler.optimizer_lattice(interp), callee_result.extended_rt,
6264
CallerMappingState(callee_result, interp.var_to_diff, interp.varclassification, interp.varkinds, VarEqClassification[]), mapping)
6365

@@ -164,7 +166,7 @@ function tfunc(F::Union{Val{Core.Intrinsics.sub_float}, Val{Core.Intrinsics.sub_
164166
return Incidence(const_val, rrow)
165167
end
166168

167-
function tfunc(::Val{Core.Intrinsics.mul_float}, @nospecialize(a::Union{Const, Type{Float64}, Incidence}), @nospecialize(b::Union{Const, Type{Float64}, Incidence}))
169+
function tfunc(::Union{Val{Core.Intrinsics.mul_float}, Val{Core.Intrinsics.mul_int}}, @nospecialize(a::Union{Const, Type{Float64}, Incidence}), @nospecialize(b::Union{Const, Type{Float64}, Incidence}))
168170
if a === Float64 || b === Float64
169171
return Float64
170172
end
@@ -341,7 +343,7 @@ is_any_incidence(@nospecialize args...) = any(@nospecialize(x)->isa(x, Incidence
341343
if is_any_incidence(a)
342344
if f == Core.Intrinsics.neg_float || f === Core.Intrinsics.neg_int
343345
return tfunc(Val(f), a)
344-
elseif f === Core.Intrinsics.ctlz_int || f === Core.Intrinsics.not_int || f === Core.Intrinsics.abs_float
346+
elseif f === Core.Intrinsics.ctlz_int || f === Core.Intrinsics.not_int || f === Core.Intrinsics.abs_float || f === Core.Intrinsics.rint_llvm
345347
return generic_math_onearg(f, a)
346348
end
347349
end
@@ -351,7 +353,7 @@ is_any_incidence(@nospecialize args...) = any(@nospecialize(x)->isa(x, Incidence
351353
if is_any_incidence(a, b)
352354
if (f == Core.Intrinsics.add_float || f == Core.Intrinsics.sub_float) ||
353355
(f == Core.Intrinsics.add_int || f == Core.Intrinsics.sub_int) ||
354-
(f == Core.Intrinsics.mul_float || f == Core.Intrinsics.div_float) ||
356+
(f == Core.Intrinsics.mul_float || f == Core.Intrinsics.div_float || f == Core.Intrinsics.mul_int) ||
355357
f == Core.Intrinsics.copysign_float
356358
return tfunc(Val(f), a, b)
357359
elseif f in (Core.Intrinsics.or_int, Core.Intrinsics.and_int, Core.Intrinsics.xor_int,

0 commit comments

Comments
 (0)