diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala index 8861c2b366..fcb1968559 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala @@ -11,7 +11,6 @@ import hkmc2.utils.* import semantics.* import semantics.Elaborator.{State, Ctx, ctx} import mlscript.utils.algorithms.partitionScc -import java.util.IdentityHashMap import hkmc2.syntax.Literal import hkmc2.{codegen => argss} @@ -405,17 +404,78 @@ class BlockSimplifier end apply - + // * A reference we may substitute for another reference, together with + // * the assignment facts that must still be current for the substitution + // * to be sound. Requirements are compared by object identity below, so + // * they precisely describe the data-flow state observed when the fact + // * was recorded. + case class TrackedRef(ref: Value.RefLike, requirements: Set[LocalVar -> AssignInfo]): + def isCurrent: Bool = + requirements.forall((loc, asst) => assignedResults(loc) is asst) + + // * Summary of the direct value that can be propagated for a local. + // * - `false` means no known value. + // * - `true` means definitely uninitialized, + // * meaning any variable access can be replaced by `undefined`, ie, `Value.Lit(UnitLit(false))`. + // * - `Value` means this exact value is still available for propagation. + type KnownValue = Bool | Value + + // * The propagated value fact for an assignment, plus equivalent + // * references that could be substituted while their requirements hold. + case class ValueAnalysis(litValue: KnownValue, refs: List[TrackedRef]) + + object ValueAnalysis: + + val conservative: ValueAnalysis = ValueAnalysis(false, Nil) + + // * Keep a value fact only when all merged control-flow paths agree. 
+ def mergeLitValues(l: KnownValue, r: KnownValue): KnownValue = + (l, r) match + case (false, _) | (_, false) => false + case (true, true) => true + case (true, v: Value) => v + case (v: Value, true) => v + case (v1: Value, v2: Value) if v1 === v2 => v1 + case _ => false + + def mergeRefs(l: List[TrackedRef], r: List[TrackedRef]): List[TrackedRef] = + l.flatMap: lr => + r.collect: + case rr if lr.ref === rr.ref => + TrackedRef(lr.ref, lr.requirements ++ rr.requirements) + + end ValueAnalysis + + // * An unsaturated pure call that can be spliced into a later call, as in + // * `let f = foo(x); f(y)` ~> `foo(x)(y)`, provided every local captured + // * by the prefix still denotes the same assignment fact. + case class TrackedPureCall(call: Call, requirements: Set[LocalVar -> AssignInfo]): + def isCurrent: Bool = + requirements.forall((loc, asst) => !capturedVars(loc) && (assignedResults(loc) is asst)) + + // * Data-flow fact for the latest assignment known for a local variable. + // * Facts are intentionally immutable so derived analyses can be cached in + // * lazy values and compared by identity when validating requirements. enum AssignInfo: case Unknown case Uninitialized - case Assigned(asst: Assign, varAsst: Opt[Value.RefLike -> AssignInfo]) + // * `varAsst` is defined if the RHS is a direct reference to some local L, + // * so that the current variable can be treated as an alias of L as long as L's + // * associated `AssignInfo` assignment facts remain current. + // * `rhsRequirements` tracks + // * all local variables mentioned by the RHS so pure-call prefixes do not + // * outlive locals that were only valid in a narrower scope. + case Assigned( + asst: Assign, + varAsst: Opt[Value.RefLike -> AssignInfo], + rhsRequirements: Set[LocalVar -> AssignInfo], + ) case Merge(asst1: AssignInfo, asst2: AssignInfo) override def toString: String = this match case Unknown => "?" 
case Uninitialized => "∅" - case Assigned(asst, varAsst) => s"${asst.rhs}${varAsst.fold("")("‹"+_+"›")}" + case Assigned(asst, varAsst, _) => s"${asst.rhs}${varAsst.fold("")("‹"+_+"›")}" case Merge(a1, a2) => s"{${a1.toString} | ${a2.toString}}" def merge(that: AssignInfo): AssignInfo = @@ -430,6 +490,79 @@ class BlockSimplifier case Uninitialized => that case _: Assigned | _: Merge => Merge(this, that) + // * This lazy val is used to avoid retraversing the DAG and to deduplicate entries. + // * There are more efficient ways of traversing the DAG (e.g. using a mutable visited set), + // * which could avoid merging so many intermediate sets, + // * but this is simpler and should be sufficient for now. + lazy val assigns: Opt[Set[Assigned]] = this match + case a: Assigned => S(Set.single(a)) + case Merge(asst1, asst2) => + // * `for` is only for rich kids + asst1.assigns match + case N => N + case S(set1) => + asst2.assigns match + case N => S(set1) + case S(set2) => S(set1 ++ set2) + case Uninitialized => S(Set.empty) + case Unknown => N + + lazy val valueAnalysis: ValueAnalysis = this match + case Unknown => + ValueAnalysis.conservative + case Uninitialized => + ValueAnalysis(true, Nil) + case Assigned(ass, opt, _) => + val litValue = ass.rhs match + case v @ Value.Lit(_) => v + case _ => false + val refs = opt match + case S((r @ Value.SimpleRef(lv: LocalVar)) -> rhs) => + val requirement = lv -> rhs + TrackedRef(r, Set.single(requirement)) :: rhs.valueAnalysis.refs + case S(ref -> rhs) => + TrackedRef(ref, + // * Other types of direct references don't need requirements because they cannot be reassigned: + // * indeed, `mut val` is not valid outside of an object/module scope, + // * and if defined in such a scope, a `mut val x` would be referred to through `this.x`. 
+ Set.empty + ) :: rhs.valueAnalysis.refs + case N => Nil + ValueAnalysis(litValue, refs) + case Merge(asst1, asst2) => + // * [Future: dead assignment removal] + // FIXME: this currently short-circuits, which will miss some live assignments... + val l = asst1.valueAnalysis + if l.refs.isEmpty && l.litValue === false then + ValueAnalysis.conservative + else + val r = asst2.valueAnalysis + ValueAnalysis( + ValueAnalysis.mergeLitValues(l.litValue, r.litValue), + ValueAnalysis.mergeRefs(l.refs, r.refs)) + + lazy val pureCallPrefix: Opt[TrackedPureCall] = this match + case Unknown | Uninitialized => N + case Assigned(ass, opt, rhsRequirements) => + ass.rhs match + case call: Call if call.isKnownUnsaturatedCall && call.isPure => + S(TrackedPureCall(call, rhsRequirements)) + case _ => + opt match + case S((Value.SimpleRef(next: LocalVar), originalAsst)) => + // * If the RHS was a variable that was at the time assigned to a pure call prefix, + // * we can directly pick up that call, regardless of the current status of that variable. 
+ originalAsst.pureCallPrefix + case _ => N + case Merge(asst1, asst2) => + asst1.pureCallPrefix match + case S(call1) => + asst2.pureCallPrefix match + case S(call2) if call1.call === call2.call => + S(TrackedPureCall(call1.call, call1.requirements ++ call2.requirements)) + case _ => N + case N => N + import AssignInfo.* @@ -513,16 +646,16 @@ class BlockSimplifier case ass @ Assign(lhs: LocalVar, rhs, rst) if !capturedVars(lhs) => // log(s"Propagating ${lhs} := ${rhs} (${assignedResults.get(lhs)})") - assignedResults += lhs -> Assigned(ass, rhs.match + val varAsst = rhs.match case r @ Value.SimpleRef(sym: LocalVar) => if capturedVars(sym) then N - else - val rhs2 = assignedResults(sym) - S(r -> rhs2) - case r: Value.RefLike => - S(r -> Unknown) + else S(r -> assignedResults(sym)) + case r: Value.RefLike => S(r -> Unknown) case _ => N - ) + val rhsRequirements = rhs.freeVars.iterator.collect: + case sym: LocalVar if !capturedVars(sym) => + sym -> assignedResults(sym) + assignedResults += lhs -> Assigned(ass, varAsst, rhsRequirements.toSet) super.applyBlock(b) @@ -611,33 +744,32 @@ class BlockSimplifier sym.irClsLikeDefn.exists: defn => val paramLists = defn.paramsOpt.toList ::: defn.auxParams paramLists.lengthCompare(argss.length) === 0 - def getShapesA(a: AssignInfo): Set[Shape] = - // trace[Set[Shape]](s"Getting shapes for assignment ${a}", r => s"= ${r}"): - a match - case Unknown => giveUp - case Uninitialized => Set.empty - case Merge(a1, a2) => getShapesA(a1) | getShapesA(a2) - case Assigned(asst, varAsst) => - varAsst match - case S(Value.MemberRef(r, sym: ModuleOrObjectSymbol) -> _) => - Set.single(sym) - case S(_ -> ass) => - getShapesA(ass) - case N => - asst.rhs match - case p: Path => getShapes(p) - case Call(path, args) => - getCtorShape(path) match - case S(sym: ClassSymbol) if isSaturatedClassCall(sym, args) => - Set.single(sym) - case _ => giveUp - case Instantiate(_, cls, _) => - // * Note: Instantiate nodes are globally assumed to be saturated - 
getCtorShape(cls) match - case S(sym) => - Set.single(sym) + def getAssignInfoShapes(a: AssignInfo): Set[Shape] = + if gaveUp then Set.empty // * already gave up: skip the traversal, as `getShapes` does + else a.assigns match + case N => giveUp + case S(assts) => assts.flatMap: + case Assigned(asst, varAsst, _) => + varAsst match + case S(Value.MemberRef(r, sym: ModuleOrObjectSymbol) -> _) => + Set.single(sym) + case S(_ -> ass) => + getAssignInfoShapes(ass) + case N => + asst.rhs match + case p: Path => getShapes(p) + case Call(path, args) => + getCtorShape(path) match + case S(sym: ClassSymbol) if isSaturatedClassCall(sym, args) => + Set.single(sym) + case _ => giveUp + case Instantiate(_, cls, _) => + // * Note: Instantiate nodes are globally assumed to be saturated + getCtorShape(cls) match + case S(sym) => + Set.single(sym) + case _ => giveUp case _ => giveUp - case _ => giveUp def getShapes(p: Path): Set[Shape] = if gaveUp then Set.empty else @@ -645,7 +777,7 @@ class BlockSimplifier case Value.SimpleRef(r: LocalVar) if capturedVars(r) => giveUp case Value.SimpleRef(r: LocalVar) => - assignedResults.get(r).fold(giveUp)(getShapesA) + assignedResults.get(r).fold(giveUp)(getAssignInfoShapes) case Value.MemberRef(r, sym: ModuleOrObjectSymbol) => Set.single(sym) case Value.Lit(lit) => Set.single(lit) @@ -731,65 +863,12 @@ class BlockSimplifier val rs = assignedResults(loc) // log(s"Ref ${loc.showDbg} ${rs} ${localVars(loc)} ${capturedVars(loc)}") - def analyzeAssignments(asst: AssignInfo): Unit = - asst match - case Unknown | Uninitialized => () - case Merge(a1, a2) => - analyzeAssignments(a1) - analyzeAssignments(a2) - case Assigned(ass, _) => - // * [Future: dead assignment removal] - // liveAssignments.put(ass, ()) + val analysis = rs.valueAnalysis + val refs = analysis.refs.iterator.filter(_.isCurrent).map(_.ref).toList - var litValue: Bool | Value = true - var emptyHanded = false - - def analyzeValues(asst: AssignInfo): Set[Value.RefLike] = - if emptyHanded && litValue === false then - analyzeAssignments(asst) - Set.empty - 
else asst match - case Unknown => - litValue = false - Set.empty - case Uninitialized => Set.empty - case Assigned(ass, opt) => - // * [Future: dead assignment removal] - // liveAssignments.put(ass, ()) - - if litValue =/= false then - ass.rhs match - case v @ Value.Lit(lit) => - if litValue === true then - litValue = v - else if litValue =/= v then - litValue = false - case _ => - litValue = false - opt match - case S((r @ Value.SimpleRef(lv: LocalVar)) -> rhs) => - if assignedResults(lv) is rhs - then Set.single(r) ++ analyzeValues(rhs) - else Set.empty - case S(lv -> rhs) => - Set.single(lv) ++ analyzeValues(rhs) - case N => Set.empty - case Merge(a1, a2) => - // * [Future: dead assignment removal] - // FIXME: this currently short-circuits, which will miss some live assignments... - - val l = analyzeValues(a1) - if l.isEmpty && litValue === false then - emptyHanded = true - analyzeAssignments(a2) - Set.empty - else l & analyzeValues(a2) + // log(s"Analysis: litValue: ${analysis.litValue}, unchanged vars: ${refs}") - val vars = analyzeValues(rs) - - // log(s"Analysis: litValue: ${litValue}, unchanged vars: ${vars}") - - litValue match + analysis.litValue match case true => registerChange(s"${loc.showDbg} ~> undefined") return k(Value.Lit(syntax.Tree.UnitLit(false))) @@ -797,33 +876,20 @@ class BlockSimplifier registerChange(s"${loc.showDbg} ~> ${lit.showDbg}") return k(lit) case false => - vars.minByOption(_.symbol.uid) match + refs.minByOption(_.symbol.uid) match case N => k(v) case S(v2) => - registerChange(s"${loc.showDbg} ~> ${v2.showDbg} (via ${vars.map(_.showDbg).mkString(", ")})") + registerChange(s"${loc.showDbg} ~> ${v2.showDbg} (via ${refs.map(_.showDbg).mkString(", ")})") k(v2) case _ => super.applyValue(v)(k) private def assignedPureCallPrefix(loc: LocalVar): Opt[Call] = - def loop(asst: AssignInfo, seen: Set[LocalVar]): Opt[Call] = - asst match - case Unknown | Uninitialized => N - case Assigned(ass, opt) => - ass.rhs match - case call: Call if 
call.isKnownUnsaturatedCall && call.isPure => S(call) - case _ => - opt match - case S((Value.SimpleRef(next: LocalVar), nextAsst)) - if !capturedVars(next) && !seen(next) && (assignedResults(next) is nextAsst) => - loop(nextAsst, seen + next) - case _ => N - case Merge(asst1, asst2) => - (loop(asst1, seen), loop(asst2, seen)) match - case (S(call1), S(call2)) if call1 == call2 => S(call1) - case _ => N - loop(assignedResults(loc), Set.single(loc)) + // * Only expose prefixes whose dependency facts still match the current + // * data-flow state; otherwise the prefix may mention stale scoped locals. + assignedResults(loc).pureCallPrefix.collect: + case prefix if prefix.isCurrent => prefix.call override def applyResult(r: Result)(k: Result => Block): Block = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index 0f7579e321..ebaaea61d0 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -750,7 +750,10 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val ctx = HandlerCtx.TopLevel val transformed = translateBlock(preTransformed, ctx, Set.empty) val blk = blockBuilder - .assign(State.noSymbol, Call(paths.resetEffects, Nil ne_:: Nil)(true, false, false)) + .staticif( + !opt.doNotInstrumentTopLevelModCtor, + _.assign(State.noSymbol, Call(paths.resetEffects, Nil ne_:: Nil)(true, false, false)) + ) .rest(transformed) (blk, stackSafetyMap) diff --git a/hkmc2Benchmarks/src/test/bench/mlscript-compile/Benchmark.mls b/hkmc2/shared/src/test/mlscript-compile/Benchmark.mls similarity index 100% rename from hkmc2Benchmarks/src/test/bench/mlscript-compile/Benchmark.mls rename to hkmc2/shared/src/test/mlscript-compile/Benchmark.mls diff --git a/hkmc2/shared/src/test/mlscript/opt/BasicVarPropag.mls b/hkmc2/shared/src/test/mlscript/opt/BasicVarPropag.mls index 
348bf99df3..f9f4427b3c 100644 --- a/hkmc2/shared/src/test/mlscript/opt/BasicVarPropag.mls +++ b/hkmc2/shared/src/test/mlscript/opt/BasicVarPropag.mls @@ -483,3 +483,4 @@ in y + 1 //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— //│ = 2 + diff --git a/hkmc2/shared/src/test/mlscript/opt/PureCallPropagation.mls b/hkmc2/shared/src/test/mlscript/opt/PureCallPropagation.mls index bef91c58c5..c77806027b 100644 --- a/hkmc2/shared/src/test/mlscript/opt/PureCallPropagation.mls +++ b/hkmc2/shared/src/test/mlscript/opt/PureCallPropagation.mls @@ -140,3 +140,46 @@ f(3) + f(4) //│ f = fun +// * Do not propagate an unsaturated call prefix if it mentions a local +// * that has gone out of scope by the later call site. +:soir +fun scopedPrefixFoo(x)(y) = x + y +() => + let f + if maybe is + false then 0 + true then + scope.locally of ( + let x = hide(0) + f = scopedPrefixFoo(x) + ) + print of + "Hello", "Hello", "Hello", "Hello", "Hello", "Hello", "Hello", "Hello", "Hello", "Hello", "Hello" + f(2) +//│ ——————————————| Optimized IR |—————————————————————————————————————————————————————————————————————— +//│ let scopedPrefixFoo⁰, lambda³; +//│ @inline +//│ define scopedPrefixFoo⁰ as fun scopedPrefixFoo¹(x)(y) { +//│ return +⁰(x, y) +//│ }; +//│ @private +//│ define lambda³ as fun lambda⁴() { +//│ let f, scrut; +//│ set scrut = Predef⁰.maybe⁰; +//│ match scrut +//│ false => +//│ end +//│ true => +//│ let x; +//│ set x = Predef⁰.hide⁰(0); +//│ set f = scopedPrefixFoo¹(x); +//│ end +//│ else +//│ throw new globalThis⁰.Error⁰("match error") +//│ do Predef⁰.print⁰("Hello", "Hello", "Hello", "Hello", "Hello", "Hello", "Hello", "Hello", "Hello", "Hello", "Hello"); +//│ return f(2) +//│ }; +//│ return lambda⁴ +//│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— +//│ = fun + diff --git a/hkmc2/shared/src/test/mlscript/opt/VarPropagStress.mls 
b/hkmc2/shared/src/test/mlscript/opt/VarPropagStress.mls new file mode 100644 index 0000000000..b067592f04 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/opt/VarPropagStress.mls @@ -0,0 +1,37 @@ +:js + + +// FIXME: Exponential blow up? The time this takes roughly doubles with each additional line (try it with `watcher`) +let x = 0 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +if maybe do set x += 1 +// if maybe do set x += 1 +// if maybe do set x += 1 +// if maybe do set x += 1 +// if maybe do set x += 1 +// if maybe do set x += 1 +x +//│ = 22 +//│ x = 22 + + diff --git a/hkmc2Benchmarks/src/test/bench/FingerTreesAsStacks.mls b/hkmc2Benchmarks/src/test/bench/FingerTreesAsStacks.mls index d22d45ed82..7e40e575a8 100644 --- a/hkmc2Benchmarks/src/test/bench/FingerTreesAsStacks.mls +++ b/hkmc2Benchmarks/src/test/bench/FingerTreesAsStacks.mls @@ -3,7 +3,7 @@ import "../../../../hkmc2/shared/src/test/mlscript-compile/FingerTreeList.mls" import "../../../../hkmc2/shared/src/test/mlscript-compile/Stack.mls" -import "./mlscript-compile/Benchmark.mls" +import "../../../../hkmc2/shared/src/test/mlscript-compile/Benchmark.mls" open Benchmark diff --git a/hkmc2Benchmarks/src/test/bench/LazySpreads.mls b/hkmc2Benchmarks/src/test/bench/LazySpreads.mls index 42cd397a63..b4a01094bc 100644 --- a/hkmc2Benchmarks/src/test/bench/LazySpreads.mls +++ b/hkmc2Benchmarks/src/test/bench/LazySpreads.mls @@ -4,7 +4,7 @@ import "../../../../hkmc2/shared/src/test/mlscript-compile/LazyArray.mls" import 
"../../../../hkmc2/shared/src/test/mlscript-compile/FingerTreeList.mls" import "../../../../hkmc2/shared/src/test/mlscript-compile/LazyFingerTree.mls" -import "./mlscript-compile/Benchmark.mls" +import "../../../../hkmc2/shared/src/test/mlscript-compile/Benchmark.mls" import "../../../../hkmc2/shared/src/test/mlscript-compile/Runtime.mls" open LazyArray diff --git a/hkmc2Benchmarks/src/test/bench/nofibs-bench1.mls b/hkmc2Benchmarks/src/test/bench/nofibs-bench1.mls index 84aff684bc..20151851c9 100644 --- a/hkmc2Benchmarks/src/test/bench/nofibs-bench1.mls +++ b/hkmc2Benchmarks/src/test/bench/nofibs-bench1.mls @@ -6,7 +6,7 @@ ==================================================================================================== -import "./mlscript-compile/Benchmark.mls" +import "../../../../hkmc2/shared/src/test/mlscript-compile/Benchmark.mls" open Benchmark diff --git a/hkmc2Benchmarks/src/test/bench/nofibs-bench2.mls b/hkmc2Benchmarks/src/test/bench/nofibs-bench2.mls index fc9fcd3b9b..c6c4064a54 100644 --- a/hkmc2Benchmarks/src/test/bench/nofibs-bench2.mls +++ b/hkmc2Benchmarks/src/test/bench/nofibs-bench2.mls @@ -7,7 +7,7 @@ ==================================================================================================== -import "./mlscript-compile/Benchmark.mls" +import "../../../../hkmc2/shared/src/test/mlscript-compile/Benchmark.mls" open Benchmark diff --git a/hkmc2Benchmarks/src/test/bench/nofibs-bench3.mls b/hkmc2Benchmarks/src/test/bench/nofibs-bench3.mls index d0bc7cd8e4..72efc106a4 100644 --- a/hkmc2Benchmarks/src/test/bench/nofibs-bench3.mls +++ b/hkmc2Benchmarks/src/test/bench/nofibs-bench3.mls @@ -7,7 +7,7 @@ ==================================================================================================== -import "./mlscript-compile/Benchmark.mls" +import "../../../../hkmc2/shared/src/test/mlscript-compile/Benchmark.mls" open Benchmark