Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
2555530
WIP: Fix diff-test inlining threshold
LPTK May 29, 2026
5ddcf17
WIP: Fix diff-test inlining threshold
LPTK Jun 2, 2026
2fe8785
Changes before error encountered
Copilot May 29, 2026
597e094
Fix SymbolRefresher: copy defn to refreshed ClassSymbol and handle pr…
Copilot Jun 1, 2026
bd1e364
Update golden test snapshots after SymbolRefresher fix
Copilot Jun 1, 2026
0e5a395
Fix method inlining: restrict to module methods only, fix JS codegen …
Copilot Jun 1, 2026
85c9468
Update tests
LPTK Jun 2, 2026
3e6165a
Cleanup
LPTK Jun 2, 2026
a2dbcba
Minor
LPTK Jun 2, 2026
75ecd43
Revert strangely unnecessary changes by Copilot
LPTK Jun 2, 2026
09fe6d3
wat
LPTK Jun 2, 2026
234970f
Fix class constructor symbol refresh
LPTK Jun 2, 2026
5f4cfeb
rewrite progress
AnsonYeung Jun 3, 2026
5a38ed9
Fix BlockTransformer
AnsonYeung Jun 3, 2026
ca726e5
Fix methods
AnsonYeung Jun 3, 2026
23320dd
Fix public fields
AnsonYeung Jun 3, 2026
27d6e45
Fix lifted class
AnsonYeung Jun 3, 2026
602e6a2
Fix a silly bug
AnsonYeung Jun 3, 2026
dcf7cc9
Remove the odd types
AnsonYeung Jun 3, 2026
ca92234
Some final changes and commit tests
AnsonYeung Jun 3, 2026
42af285
Adjust comment
AnsonYeung Jun 3, 2026
bf5da1c
More docs
AnsonYeung Jun 3, 2026
5233b5f
workaround
AnsonYeung Jun 4, 2026
55d0075
Update
AnsonYeung Jun 4, 2026
93de8b7
add todo
AnsonYeung Jun 4, 2026
7bfc6e8
Minor
LPTK Jun 5, 2026
85a5e11
Merge pull request #28 from AnsonYeung/fix-symbol-refresher
LPTK Jun 5, 2026
c509754
Merge branch 'hkmc2' into inliner-lifter-refresher-fixes
LPTK Jun 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ object Config:
tailRecOpt = true,
deforest = N,
etaExpansion = S(EtaExpansion.default),
inlining = S(Inliner(10)),
inlining = S(Inliner(default.inlineThreshold)),
deadBranchRemoval = default.deadBranchRemoval,
qqEnabled = false,
funcToCls = false,
Expand All @@ -83,6 +83,7 @@ object Config:
object default:
val patMatConsequentSharingThreshold = S(15)
val deadBranchRemoval = false // TODO
val inlineThreshold = 10

case class SanityChecks(light: Bool, checkUnreachable: Bool)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ class BlockSimplifier
case _ => super.applyBlock(b)

def applyBlock(blk: Block) =
Label(lblSym, false, Copier.applyBlock(blk), _)
Label(lblSym, false, Copier.apply(blk), _)

class Transformer(m: InlinerMap) extends BlockTransformer(SymbolSubst()):

Expand Down
15 changes: 14 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,20 @@ class BlockTransformer(subst: SymbolSubst):
def applyScopedBlock(b: Block): Block = b match
case Scoped(s, bd) =>
val nb = applySubBlock(bd)
if nb is bd then b else Scoped(s, nb)

// Set does not have .mapConserve
var hasDiff = false
val ns = s.map[ScopedSymbol]:
case s: LocalVarSymbol =>
val ns = s.subst
if ns isnt s then hasDiff = true
ns
case s: BlockMemberSymbol =>
val ns = s.subst
if ns isnt s then hasDiff = true
ns

if (nb is bd) && !hasDiff then b else Scoped(ns, nb)
case _ => applySubBlock(b)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise):
val filteredParams = filterFunParams(fDefn.dSym, fDefn.params, instId)
val transformedBody = new Rewriter(instId).rewriteFunBody(fDefn.dSym, fDefn.params, fDefn.body)
val (refreshedParams, refreshParamMap) = makeRefreshedParams(filteredParams)
val bodyWithCorrectSymbols = new RefreshSymbol(refreshParamMap).applyBlock(transformedBody)
val bodyWithCorrectSymbols = new RefreshSymbol(refreshParamMap).apply(transformedBody)
FunDefn(
N, bms, tSym, refreshedParams,
bodyWithCorrectSymbols)(fDefn.configOverride, fDefn.annotations)
Expand Down
419 changes: 138 additions & 281 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala

Large diffs are not rendered by default.

34 changes: 31 additions & 3 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,36 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise):
case Some(bms) =>
k(bms.asMemberRef(l.asMod.get))
case None => super.applyValue(v)(k)
// We may generate bms -> VarSymbol replacement, which require change of IR
case Value.MemberRef(bms, disamb) if existingMapping.contains(bms) =>
val sym = existingMapping(bms)
sym match
case s: VarSymbol => k(Value.SimpleRef(s))
case _ => super.applyValue(v)(k)
case _ => super.applyValue(v)(k)
end RefreshSymbol

// FIXME: This is a temporary workaround for the deforestation problem discussed at
// https://discord.com/channels/884326249617559653/935507764384501760/1512041398520774832
// It should be fixed after we fix Lowering; then, this function should be removed,
// and its uses replaced back by `new RefreshSymbol(refreshParamMap.toMap).apply`
// TODO: it may cause BMS to malfunction if both class and function refer to same BMS and they are split apart
def refreshExtractedBody(mapping: Map[Symbol, Symbol], body: Block): Block =
val defined = MutSet.empty[ScopedSymbol]
val missing = MutSet.empty[ScopedSymbol]
val walker = new BlockTraverserShallow:
override def applyBlock(b: Block): Unit = b match
case Scoped(syms, body) =>
defined.addAll(syms)
applyBlock(body)
case Define(defn: FunDefn, rest) =>
if !defined(defn.sym) then
missing.add(defn.sym)
defined.add(defn.sym)
applyBlock(rest)
case _ => super.applyBlock(b)
walker.applyBlock(body)
new RefreshSymbol(mapping).apply(Scoped(missing, body))

// Instantiated polymorphic functions
val newPolyFuns =
Expand All @@ -507,7 +535,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise):
refreshParamMap(p.sym) = newSym
Param(p.flags, newSym, p.sign, p.modulefulness),
pl.restParam)
val bodyWithCorrectSymbols = new RefreshSymbol(refreshParamMap.toMap).applyBlock(transformedBody)
val bodyWithCorrectSymbols = refreshExtractedBody(refreshParamMap.toMap, transformedBody)
FunDefn(
N, bms, tSym, refreshedParams,
bodyWithCorrectSymbols)(N, PrivateModifier :: fDefn.annotations)
Expand All @@ -524,7 +552,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise):
new Rewriter(instId).applyBlock(ogBody),
Return(mkCall(restFunSym, restFunArgs)))
val refreshedFvSymbols = dtorBranchFnFvs(branchId._1).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}")))
val bodyWithCorrectSymbols = new RefreshSymbol(refreshedFvSymbols.toMap).applyBlock(actualBody)
val bodyWithCorrectSymbols = refreshExtractedBody(refreshedFvSymbols.toMap, actualBody)
FunDefn(N, bms, tSym,
branchFunParamFieldSyms(branchId).asParamList :: refreshedFvSymbols.unzip._2.asParamList :: Nil,
bodyWithCorrectSymbols
Expand All @@ -548,7 +576,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise):
case None =>
Begin(transformedOgBody, Return(Value.Lit(Tree.UnitLit(true))))
val refreshedFvSymbols = restFnFvs(restFunId).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}")))
val bodyWithCorrectSymbols = new RefreshSymbol(refreshedFvSymbols.toMap).applyBlock(actualBody)
val bodyWithCorrectSymbols = refreshExtractedBody(refreshedFvSymbols.toMap, actualBody)
FunDefn(N, bms, tsym, refreshedFvSymbols.unzip._2.asParamList :: Nil, bodyWithCorrectSymbols)(N, annotations = PrivateModifier :: Nil)
end newRestFuns

Expand Down
7 changes: 6 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,11 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder:
doc" # switch (${result(scrut)}) { #{ ${bodWithDflt} #} # }" :: returningTerm(rest, endSemi)
case Match(scrut, arms @ hd :: tl, els, rest) =>
val sd = result(scrut)
// * Parenthesize the scrutinee for property access when it's a numeric literal,
// * since things like `12.length` are invalid JS (the `.` is parsed as a decimal point).
def sdProp = scrut match
case Value.Lit(Tree.IntLit(_) | Tree.DecLit(_)) => doc"($sd)"
case _ => sd
def cond(cse: Case) = cse match
case Case.Lit(lit) => doc"$sd === ${lit.idStr}"
case Case.Cls(cls, pth) => cls match
Expand All @@ -684,7 +689,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder:
// * ^ Note that modules are currently not valid patterns;
// * this case is just for objects, which have their class stored in a `.class` property.
case _ => doc"$sd instanceof ${result(pth)}"
case Case.Tup(len, inf) => doc"$runtimeVar.Tuple.isArrayLike($sd) && $sd.length ${if inf then ">=" else "==="} ${len}"
case Case.Tup(len, inf) => doc"$runtimeVar.Tuple.isArrayLike($sd) && $sdProp.length ${if inf then ">=" else "==="} ${len}"
case Case.Field(name = n, safe = false) =>
doc"""typeof $sd === "object" && $sd !== null && "${n.name}" in $sd"""
case Case.Field(name = n, safe = true) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fun foo =
:sjs
foo
//│ ————————————| JS (unsanitized) |————————————————————————————————————————————————————————————————————
//│ return foo2()
//│ Predef.print(1); Predef.print("..."); return runtime.Unit
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ > 1
//│ > ...
Expand Down
10 changes: 5 additions & 5 deletions hkmc2/shared/src/test/mlscript/basics/LazySpreads.mls
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ fun sum2 = case
//│ sum2 = function sum2() {
//│ let lambda;
//│ lambda = (undefined, function (caseScrut) {
//│ let lastElement1$3, middleElements3, element0$, tmp5, tmp6, tmp7;
//│ let lastElement1$3, middleElements3, element0$, tmp4, tmp5, tmp6;
//│ if (runtime.Tuple.isArrayLike(caseScrut) && caseScrut.length === 0) {
//│ return 0
//│ } else if (runtime.Tuple.isArrayLike(caseScrut) && caseScrut.length >= 2) {
//│ element0$ = runtime.Tuple.get(caseScrut, 0);
//│ middleElements3 = runtime.Tuple.slice(caseScrut, 1, 1);
//│ lastElement1$3 = runtime.Tuple.get(caseScrut, -1);
//│ tmp5 = element0$ + lastElement1$3;
//│ tmp6 = sum2();
//│ tmp7 = runtime.safeCall(tmp6(middleElements3));
//│ return tmp5 + tmp7
//│ tmp4 = element0$ + lastElement1$3;
//│ tmp5 = sum2();
//│ tmp6 = runtime.safeCall(tmp5(middleElements3));
//│ return tmp4 + tmp6
//│ }
//│ throw globalThis.Object.freeze(new globalThis.Error("match error"));
//│ });
Expand Down
17 changes: 16 additions & 1 deletion hkmc2/shared/src/test/mlscript/basics/MultiParamListClasses.mls
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,22 @@ Foo(1, 2)
:sjs
let f = Foo // should compile to `(x, y) => Foo(x, y)(use[Int])`
//│ ————————————| JS (unsanitized) |————————————————————————————————————————————————————————————————————
//│ let f, f1; f1 = function f(x, y) { return Foo7(x, y)(instance$Ident$_Int$_) }; f = f1;
//│ let f, f1;
//│ f1 = function f(x1, y1) {
//│ let x2, y2, z2, tmp2, lambda4;
//│ x2 = x1;
//│ y2 = y1;
//│ z2 = instance$Ident$_Int$_;
//│ tmp2 = new Foo03();
//│ lambda4 = (undefined, function (res) {
//│ res.x = x2;
//│ res.y = y2;
//│ res.z = z2;
//│ return runtime.Unit
//│ });
//│ return Predef.tap(tmp2, lambda4)
//│ };
//│ f = f1;
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ f = fun f

Expand Down
4 changes: 2 additions & 2 deletions hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fun f(n1: Int)(n2: Int)(n3: Int): Int = 10 * (10 * n1 + n2) + n3

f(4)(2)(0)
//│ ————————————| JS (unsanitized) |————————————————————————————————————————————————————————————————————
//│ return f$worker(4, 2, 0)
//│ return 420
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ = 420

Expand All @@ -77,7 +77,7 @@ fun f(n1: Int)(n2: Int)(n3: Int)(n4: Int): Int = 10 * (10 * (10 * n1 + n2) + n3)

f(3)(0)(3)(1)
//│ ————————————| JS (unsanitized) |————————————————————————————————————————————————————————————————————
//│ return f$worker1(3, 0, 3, 1)
//│ let tmp, tmp1; tmp = 303; tmp1 = 10 * tmp; return tmp1 + 1
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ = 3031

Expand Down
8 changes: 3 additions & 5 deletions hkmc2/shared/src/test/mlscript/basics/PartialApps.mls
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,17 @@ passTo(1, add(., 1) @ _ + _)(2)
:sjs
let f = add(_, 1)
//│ ————————————| JS (unsanitized) |————————————————————————————————————————————————————————————————————
//│ let f, f1;
//│ f1 = function f(_0) { let tmp1; tmp1 = add(); return runtime.safeCall(tmp1(_0, 1)) };
//│ f = f1;
//│ let f, f1; f1 = function f(_0) { return _0 + 1 }; f = f1;
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ f = fun f

:fixme
let f = add(., 1)
//│ ╔══[PARSE ERROR] Expected an expression; found period instead
//│ ║ l.123: let f = add(., 1)
//│ ║ l.121: let f = add(., 1)
//│ ╙── ^
//│ ╔══[PARSE ERROR] Unexpected period here
//│ ║ l.123: let f = add(., 1)
//│ ║ l.121: let f = add(., 1)
//│ ╙── ^
//│ ═══[RUNTIME ERROR] This code cannot be run as its compilation yielded an error.
//│ f = undefined
Expand Down
7 changes: 1 addition & 6 deletions hkmc2/shared/src/test/mlscript/basics/ShortcircuitingOps.mls
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@ false && loudTrue
:sjs
print(1 && loudTrue)
//│ ————————————| JS (unsanitized) |————————————————————————————————————————————————————————————————————
//│ if (1 === true) {
//│ let inlinedVal;
//│ inlinedVal = loud(true);
//│ return Predef.print(inlinedVal)
//│ }
//│ return Predef.print(false);
//│ if (1 === true) { Predef.print(true); return Predef.print(true) } return Predef.print(false);
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ > false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class Baz(a, b) with
constructor(x, y)(u, v)
fun total() = a + b + x + y + u + v

:expect 21
Baz(1, 2)(3, 4)(5, 6).total()
//│ = 21

Expand All @@ -90,6 +91,7 @@ class Qux(a) with
constructor(extra)
fun both() = a + extra

:expect 133
Qux(10)(123).both()
//│ = 133

Expand Down
16 changes: 8 additions & 8 deletions hkmc2/shared/src/test/mlscript/codegen/ClassInFun.mls
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,23 @@ fun test(x) =
//│ ————————————| JS (unsanitized) |————————————————————————————————————————————————————————————————————
//│ let test2;
//│ test2 = function test(x) {
//│ let Foo2, tmp;
//│ Foo2 = function Foo(a, b) {
//│ return globalThis.Object.freeze(new Foo.class(a, b));
//│ let Foo4, tmp;
//│ Foo4 = function Foo(a1, b) {
//│ return globalThis.Object.freeze(new Foo.class(a1, b));
//│ };
//│ (class Foo1 {
//│ (class Foo3 {
//│ static {
//│ Foo2.class = this
//│ Foo4.class = this
//│ }
//│ constructor(a, b) {
//│ this.a = a;
//│ constructor(a1, b) {
//│ this.a = a1;
//│ this.b = b;
//│ }
//│ toString() { return runtime.render(this); }
//│ static [definitionMetadata] = ["class", "Foo", ["a", "b"]];
//│ });
//│ tmp = x + 1;
//│ return Foo2(x, tmp)
//│ return Foo4(x, tmp)
//│ };
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————

Expand Down
30 changes: 16 additions & 14 deletions hkmc2/shared/src/test/mlscript/codegen/ClassMatching.mls
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,16 @@ fun f(x) = if x is
else print("oops")
//│ ————————————| JS (unsanitized) |————————————————————————————————————————————————————————————————————
//│ let f4;
//│ f4 = function f(x) {
//│ f4 = function f(x2) {
//│ let scrut1, arg$Some$0$1;
//│ if (x instanceof Some1.class) {
//│ arg$Some$0$1 = x.value;
//│ if (x2 instanceof Some1.class) {
//│ arg$Some$0$1 = x2.value;
//│ scrut1 = arg$Some$0$1 > 0;
//│ if (scrut1 === true) {
//│ return 42
//│ }
//│ return Predef.print("oops");
//│ } else if (x instanceof None1.class) { return "ok" }
//│ } else if (x2 instanceof None1.class) { return "ok" }
//│ return Predef.print("oops");
//│ };
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
Expand All @@ -189,13 +189,13 @@ fun f(x) = if x is
Pair(a, b) then a + b
//│ ————————————| JS (unsanitized) |————————————————————————————————————————————————————————————————————
//│ let f5;
//│ f5 = function f(x) {
//│ f5 = function f(x2) {
//│ let arg$Pair$0$, arg$Pair$1$;
//│ if (x instanceof Some1.class) {
//│ return x.value
//│ } else if (x instanceof Pair1.class) {
//│ arg$Pair$0$ = x.fst;
//│ arg$Pair$1$ = x.snd;
//│ if (x2 instanceof Some1.class) {
//│ return x2.value
//│ } else if (x2 instanceof Pair1.class) {
//│ arg$Pair$0$ = x2.fst;
//│ arg$Pair$1$ = x2.snd;
//│ return arg$Pair$0$ + arg$Pair$1$
//│ }
//│ throw globalThis.Object.freeze(new globalThis.Error("match error"));
Expand All @@ -217,15 +217,17 @@ fun f(x) = print of if x is
else "oops"
//│ ————————————| JS (unsanitized) |————————————————————————————————————————————————————————————————————
//│ let f6;
//│ f6 = function f(x) {
//│ f6 = function f(x2) {
//│ let arg$Some$0$1;
//│ if (x instanceof Some1.class) {
//│ arg$Some$0$1 = x.value;
//│ if (x2 instanceof Some1.class) {
//│ arg$Some$0$1 = x2.value;
//│ if (arg$Some$0$1 === 0) {
//│ return Predef.print("0")
//│ }
//│ return Predef.print("oops");
//│ } else if (x instanceof None1.class) { return Predef.print("ok") }
//│ } else if (x2 instanceof None1.class) {
//│ return Predef.print("ok")
//│ }
//│ return Predef.print("oops");
//│ };
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
Expand Down
Loading
Loading