1- # TODO : We should unify this function with _make_argument_lattice_elem to ensure consistency
21function _flatten_parameter! (𝕃, compact, argtypes, ntharg, line)
32 list = Any[]
43 for (argn, argt) in enumerate (argtypes)
@@ -8,6 +7,8 @@ function _flatten_parameter!(𝕃, compact, argtypes, ntharg, line)
87 continue
98 elseif Base. isprimitivetype (argt) || isa (argt, Incidence)
109 push! (list, ntharg (argn))
10+ elseif argt === equation || isa (argt, Eq)
11+ continue
1112 elseif isa (argt, Type) && argt <: Intrinsics.AbstractScope
1213 continue
1314 elseif isabstracttype (argt) || ismutabletype (argt) || (! isa (argt, DataType) && ! isa (argt, PartialStruct))
@@ -36,64 +37,79 @@ function flatten_parameter!(𝕃, compact, argtypes, ntharg, line)
3637end
3738
3839# Needs to match flatten_arguments!
39- function process_template_arg! (𝕃, coeffs, eq_mapping, applied_scopes, argt, template_argt, offset= 0 )
40+ function process_template_arg! (𝕃, coeffs, eq_mapping, applied_scopes, argt, template_argt, offset= 0 , eqoffset = 0 ) :: Pair{Int, Int}
4041 if isa (template_argt, Const)
4142 @assert isa (argt, Const) && argt. val === template_argt. val
42- return offset
43+ return Pair {Int, Int} ( offset, eqoffset)
4344 elseif Base. issingletontype (template_argt)
4445 @assert isa (template_argt, Type) && argt. instance === template_argt. instance
45- return offset
46+ return Pair {Int, Int} ( offset, eqoffset)
4647 elseif Base. isprimitivetype (template_argt)
4748 coeffs[offset+ 1 ] = argt
48- return offset + 1
49+ return Pair {Int, Int} (offset + 1 , eqoffset)
50+ elseif template_argt === equation
51+ eq_mapping[eqoffset+ 1 ] = argt. id
52+ return Pair {Int, Int} (offset, eqoffset + 1 )
4953 elseif isabstracttype (template_argt) || ismutabletype (template_argt) || (! isa (template_argt, DataType) && ! isa (template_argt, PartialStruct))
50- return offset
54+ return Pair {Int, Int} ( offset, eqoffset)
5155 else
5256 if ! isa (template_argt, PartialStruct) && Base. datatype_fieldcount (template_argt) === nothing
53- return offset
57+ return Pair {Int, Int} ( offset, eqoffset)
5458 end
5559 template_fields = isa (template_argt, PartialStruct) ? template_argt. fields : collect (fieldtypes (template_argt))
5660 return process_template! (𝕃, coeffs, eq_mapping, applied_scopes, Any[Compiler. getfield_tfunc (𝕃, argt, Const (i)) for i = 1 : length (template_fields)], template_fields, offset)
5761 end
5862end
5963
60- function process_template! (𝕃, coeffs, eq_mapping, applied_scopes, argtypes, template_argtypes, offset= 0 )
64+ function process_template! (𝕃, coeffs, eq_mapping, applied_scopes, argtypes, template_argtypes, offset= 0 , eqoffset = 0 )
6165 @assert length (argtypes) == length (template_argtypes)
6266 for (i, template_arg) in enumerate (template_argtypes)
63- offset = process_template_arg! (𝕃, coeffs, eq_mapping, applied_scopes, argtypes[i], template_arg, offset)
67+ ( offset, eqoffset) = process_template_arg! (𝕃, coeffs, eq_mapping, applied_scopes, argtypes[i], template_arg, offset)
6468 end
65- return offset
69+ return Pair {Int, Int} (offset, eqoffset)
70+ end
71+
72+ struct TransformedArg
73+ ssa:: Any
74+ offset:: Int
75+ eqoffset:: Int
76+ TransformedArg (@nospecialize (arg), new_offset:: Int , new_eqoffset:: Int ) = new (arg, new_offset, new_eqoffset)
6677end
6778
68- function flatten_argument! (compact:: Compiler.IncrementalCompact , argt, offset, argtypes:: Vector{Any} ):: Pair{Any, Int}
69- @assert ! isa (argt, Incidence)
79+ function flatten_argument! (compact:: Compiler.IncrementalCompact , @nospecialize ( argt) , offset:: Int , eqoffset :: Int , argtypes:: Vector{Any} ):: TransformedArg
80+ @assert ! isa (argt, Incidence) && ! isa (argt, Eq)
7081 if isa (argt, Const)
71- return Pair {Any, Int} (argt. val, offset)
82+ return TransformedArg (argt. val, offset, eqoffset )
7283 elseif Base. issingletontype (argt)
73- return Pair {Any, Int} (argt. instance, offset)
84+ return TransformedArg (argt. instance, offset, eqoffset )
7485 elseif Base. isprimitivetype (argt)
7586 push! (argtypes, argt)
76- return Pair {Any, Int} (Argument (offset+ 1 ), offset+ 1 )
87+ return TransformedArg (Argument (offset+ 1 ), offset+ 1 , eqoffset)
88+ elseif argt === equation
89+ ssa = insert_node_here! (compact, NewInstruction (Expr (:invoke , nothing , InternalIntrinsics. external_equation), Eq (eqoffset+ 1 ), compact[Compiler. OldSSAValue (1 )][:line ]))
90+ return TransformedArg (ssa, offset, eqoffset+ 1 )
7791 elseif isabstracttype (argt) || ismutabletype (argt) || (! isa (argt, DataType) && ! isa (argt, PartialStruct))
7892 ssa = insert_node_here! (compact, NewInstruction (Expr (:call , error, " Cannot IPO model arg type $argt " ), Union{}, compact[Compiler. OldSSAValue (1 )][:line ]))
79- return Pair {Any, Int} (ssa, offset )
93+ return TransformedArg (ssa, - 1 , eqoffset )
8094 else
8195 if ! isa (argt, PartialStruct) && Base. datatype_fieldcount (argt) === nothing
8296 ssa = insert_node_here! (compact, NewInstruction (Expr (:call , error, " Cannot IPO model arg type $argt " ), Union{}, compact[Compiler. OldSSAValue (1 )][:line ]))
83- return Pair {Any, Int} (ssa, offset )
97+ return TransformedArg (ssa, - 1 , eqoffset )
8498 end
85- (args, _, offset) = flatten_arguments! (compact, isa (argt, PartialStruct) ? argt. fields : fieldtypes (argt), offset, argtypes)
99+ (args, _, offset) = flatten_arguments! (compact, isa (argt, PartialStruct) ? argt. fields : collect (Any, fieldtypes (argt)), offset, eqoffset, argtypes)
100+ offset == - 1 && return TransformedArg (ssa, - 1 , eqoffset)
86101 this = Expr (:new , isa (argt, PartialStruct) ? argt. typ : argt, args... )
87102 ssa = insert_node_here! (compact, NewInstruction (this, argt, compact[Compiler. OldSSAValue (1 )][:line ]))
88- return Pair {Any, Int} (ssa, offset)
103+ return TransformedArg (ssa, offset, eqoffset )
89104 end
90105end
91106
92- function flatten_arguments! (compact:: Compiler.IncrementalCompact , argtypes, offset= 0 , new_argtypes:: Vector{Any} = Any[])
107+ function flatten_arguments! (compact:: Compiler.IncrementalCompact , argtypes:: Vector{Any} , offset:: Int = 0 , eqoffset :: Int = 0 , new_argtypes:: Vector{Any} = Any[])
93108 args = Any[]
94109 for argt in argtypes
95- (ssa, offset) = flatten_argument! (compact, argt, offset, new_argtypes)
110+ (; ssa, offset, eqoffset) = flatten_argument! (compact, argt, offset, eqoffset, new_argtypes)
111+ offset == - 1 && break
96112 push! (args, ssa)
97113 end
98- return (args, new_argtypes, offset)
114+ return (args, new_argtypes, offset, eqoffset )
99115end
0 commit comments