Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
477 changes: 229 additions & 248 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package flowAnalysis
import scala.jdk.CollectionConverters.MapHasAsScala
import utils.*
import mlscript.utils.*, shorthands.*
import hkmc2.Message.MessageContext
import semantics.*
import syntax.Tree
import scala.collection.mutable
Expand Down Expand Up @@ -87,18 +88,6 @@ type SelField = TermSymbol | Int
type FunId = (funSym: Symbol, whichParamList: Int) | ResultId
type OriginId = ResultId | FunId

/** Extracts the underlying symbol of a variable-like reference, for flow-tracking use. */
object TrackedSymOf:
def unapply(p: Value.RefLike | Select)(using Elaborator.State): Opt[Symbol] = p match
case Value.SimpleRef(sym) => S(sym)
case Value.MemberRef(_, disamb) => S(disamb)
case Value.This(sym) => S(sym)
case s: Select => s.symbol.flatMap: selSym =>
for
selTermSym <- selSym.asTrm
owner <- selTermSym.owner
_ <- owner.asMod
yield selTermSym

object TrackableFieldSelect:
def unapply(s: Select): Opt[Path -> (field: TermSymbol, owner: ClassSymbol)] =
Expand Down Expand Up @@ -133,38 +122,24 @@ object TrackableSelect:
Some((qual, field, owner))
case _ => N

object CtorRef:
/** Resolves an object reference or a class-ctor `TermSymbol` to its corresponding class/object symbol. */
private def classCtorSymbol(sym: Symbol)(using Elaborator.State): Opt[ClassSymbol | ModuleOrObjectSymbol] =
sym.asObj orElse
sym.asTrm.flatMap: tSym =>
for
cls <- tSym.owner.flatMap(_.asCls)
clsDef <- cls.irClsLikeDefn
ctorSym <- clsDef.ctorSym
if ctorSym is tSym
yield cls

def unapply(p: Path)(using Elaborator.State): Opt[ClassSymbol | ModuleOrObjectSymbol] = p match
case Value.SimpleRef(sym) => classCtorSymbol(sym)
case Value.MemberRef(_, disamb) => classCtorSymbol(disamb) orElse disamb.asCls orElse disamb.asObj
case Value.This(sym) => classCtorSymbol(sym) orElse sym.asCls
case s: Select => s.symbol.flatMap(classCtorSymbol)
case _ => N
object MemberRefTo:
def unapply(p: Path) = p.targetSymbol

object CtorCall:
def unapply(r: Result)(using Elaborator.State): Option[(ClassSymbol | ModuleOrObjectSymbol | Int) -> Ls[Arg]] =
object CtorProducer:
/** Extracts result forms that produce a concrete `Ctor` flow strategy. */
def unapply(r: Result)(using Elaborator.State): Opt[CtorCls -> Ls[Arg]] =
r match
case Instantiate(_, CtorRef(ctor), argss) => Some(ctor -> argss.flatten)
case Call(CtorRef(ctor), argss) => Some(ctor -> argss.flatten)
case CtorRef(ctor) if ctor.asObj.isDefined => Some(ctor -> Nil)
case Tuple(_, args) => Some(args.size, args)
case _ => None
case Instantiate(_, MemberRefTo(cls: ClassSymbol), argss) => S(cls -> argss.flatten)
case Call(MemberRefTo(cls: ClassCtorSymbol), argss) => S(cls.owner.get -> argss.flatten)
case MemberRefTo(ctor: ModuleOrObjectSymbol) => S(ctor -> Nil)
case Tuple(_, args) => S(args.size, args)
case _ => N

object FunRef:
def unapply(s: Path)(using Elaborator.State): Option[TermSymbol] = s match
case TrackedSymOf(tSym: TermSymbol) if tSym.k is syntax.Fun => Some(tSym)
case _ => None
case MemberRefTo(tSym: TermSymbol)
if (tSym.k is syntax.Fun) && tSym.owner.forall(_.asMod.isDefined) => S(tSym)
case _ => N

type StratVarId = Uid[StratVar]

Expand Down Expand Up @@ -957,7 +932,7 @@ class FlowConstraintsCollector(
fromStrat,
new FieldSel(sel.uid, instId)(field, owner, selRes.asConsStrat))
selRes.asProdStrat
case c@CtorCall(ctor, args) if args.forall(_.spread.isEmpty) =>
case c@CtorProducer(ctor, args) if args.forall(_.spread.isEmpty) =>
val argsStrat = args.map:
case Arg(_, a) => processResult(a)
ctor match
Expand All @@ -970,14 +945,14 @@ class FlowConstraintsCollector(
new Ctor(c.uid, instId)(ctor, clsParams.zip(argsStrat))
case _ =>
// - the size of 0 means we don't know the cls param symbols,
// so we constrain args with NoCons and this CtorCall gives NoProd
// so we constrain args with NoCons and this CtorProducer gives NoProd
// - if size > 1, we cannot handle multiple parameter class flow now,
// constrain args with NoCons and this CtorCall gives NoProd
// constrain args with NoCons and this CtorProducer gives NoProd
for a <- argsStrat do cc.constrain(a, UnknownCons)
UnknownProd
case _: ModuleOrObjectSymbol => new Ctor(c.uid, instId)(ctor, Nil)
case tupSize: Int => new Ctor(c.uid, instId)(tupSize, (0 until tupSize).zip(argsStrat).toList)
case c@CtorCall(_, args) =>
case c@CtorProducer(_, args) =>
args.foreach(arg => cc.constrain(processResult(arg.value), UnknownCons))
UnknownProd
case c@Call(fun, argss) =>
Expand All @@ -998,7 +973,7 @@ class FlowConstraintsCollector(
case i@Instantiate(_, cls, argss) => handleCallLike(i.uid, cls, argss.flatten)
case lam@Lambda(ps, body) =>
mkFunProdStrat("lam_res", ps :: Nil, body, lam.uid)
case _: Tuple => lastWords("should be handled in CtorCall")
case _: Tuple => lastWords("should be handled in CtorProducer")
case Record(_, fields) =>
fields.foreach:
case RcdArg(idx, value) =>
Expand All @@ -1007,25 +982,22 @@ class FlowConstraintsCollector(
UnknownProd
case p: Path =>
p match
case CtorRef(ctor) => UnknownProd
case refSite@FunRef(f) =>
funsToProdStratScheme.get(f) match
case Some(fScheme) =>
fScheme.instantiate(refSite.uid, f)
case None => generatedProdVars(f).asProdStrat
case refLk@TrackedSymOf(sym) =>
refLk match
case Select(p, _) => cc.constrain(processResult(p), UnknownCons)
case _ => ()
generatedProdVars(sym).asProdStrat
case _: Value.RefLike => lastWords("already handled in `TrackedSymOf` case")
case Select(qual, name) =>
case s@Select(qual, name) =>
cc.constrain(processResult(qual), UnknownCons)
UnknownProd
s.symbol.fold(UnknownProd): selSym =>
generatedProdVars(selSym).asProdStrat
case DynSelect(qual, fld, arrayIdx) =>
cc.constrain(processResult(qual), UnknownCons)
cc.constrain(processResult(fld), UnknownCons)
UnknownProd
case Value.MemberRef(_, disamb) => generatedProdVars(disamb).asProdStrat
case Value.SimpleRef(sym) => generatedProdVars(sym).asProdStrat
case Value.This(_) => UnknownProd
case Value.Lit(lit) => UnknownProd
}
end FlowConstraintsCollector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder:
case Elaborator.ctx.builtins.Bool => doc"typeof $sd === 'boolean'"
case Elaborator.ctx.builtins.Int => doc"globalThis.Number.isInteger($sd)"
case Elaborator.ctx.builtins.BigInt => doc"typeof $sd === 'bigint'"
case Elaborator.ctx.builtins.Symbol.module => doc"typeof $sd === 'symbol'"
case Elaborator.ctx.builtins.Symbol => doc"typeof $sd === 'symbol'"
case Elaborator.ctx.builtins.TypedArray =>
doc"globalThis.ArrayBuffer.isView($sd) && !($sd instanceof globalThis.DataView)"
case _: ModuleOrObjectSymbol => doc"$sd instanceof ${result(pth)}.class"
Expand Down
3 changes: 2 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ object Elaborator:
val Object = assumeBuiltinCls("Object")
val Array = assumeBuiltinCls("Array")
val TypedArray = assumeBuiltinCls("TypedArray")
val Symbol = assumeBuiltinCls("Symbol")
// println(s"Builtins: $Int, $Num, $Str, $untyped")
class VirtualModule(val module: ModuleOrObjectSymbol):
val bms = getBuiltin(module.nme) match
Expand All @@ -240,7 +241,7 @@ object Elaborator:
module.tree.definedSymbols.get(nme).getOrElse:
throw new NoSuchElementException(
s"builtin module symbol source.$nme")
object Symbol extends VirtualModule(assumeBuiltinObj("Symbol")):
object SymbolModule extends VirtualModule(assumeBuiltinMod("Symbol")):
val `for` = assumeObject("for")
val iterator = assumeObject("iterator")
object source extends VirtualModule(assumeBuiltinMod("source")):
Expand Down
12 changes: 10 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,16 @@ enum Annot extends AutoLocated:
case TailCall
case Inline
case Config(modify: hkmc2.Config => hkmc2.Config)
// marks if a function or lambda is affine, i.e. called at most once.
// for functions with multiple parameter lists, `whichParamList` is the zero-based parameter-list index.
// Marks if a function or lambda is one-shot, i.e. called at most once.
// Functions with multiple parameter lists are considered here as a chain of
// function values. `whichParamList` is the zero-based index of the parameter
// list whose corresponding function value is one-shot.
// For example, on `fun f(a)(b)`,
// - its list of annotations containing `Affine(0)` says that `f` is one-shot;
// - its list of annotations containing `Affine(1)` says that
// each function value produced by `f(a)` is one-shot;
// - its list of annotations containing both `Affine(0)` and `Affine(1)` says that
// `f` is one-shot and each function value produced by `f(a)` is also one-shot.
case Affine(whichParamList: Int)

def symbol: Opt[Symbol] = this match
Expand Down
12 changes: 6 additions & 6 deletions hkmc2/shared/src/test/mlscript/codegen/ConfigDirective.mls
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,12 @@ fun fst(p) = if p is Pair(x, y) then x
fun applyTwice(f, x) = f(x) + f(x)
applyTwice(fst, Pair(1, 2))
//│ deforest > >>> non-affine syms >>>
//│ deforest > cap_ub_(term:fst,0)_for_fst@36
//│ deforest > f_for_applyTwice@25
//│ deforest > f_for_applyTwice@58
//│ deforest > fst_for_fst@33
//│ deforest > cap_ub_(term:fst,0)_for_fst@35
//│ deforest > f_for_applyTwice@24
//│ deforest > f_for_applyTwice@55
//│ deforest > fst_for_fst@32
//│ deforest > tmp@1
//│ deforest > x_for_applyTwice@24
//│ deforest > x_for_applyTwice@57
//│ deforest > x_for_applyTwice@23
//│ deforest > x_for_applyTwice@54
//│ deforest > <<< non-affine syms <<<
//│ = 2
8 changes: 5 additions & 3 deletions hkmc2/shared/src/test/mlscript/decls/Prelude.mls
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ declare module Array with
isArray
prototype

declare object Symbol with
declare class Symbol
declare module Symbol with
// The `TermDef` needs `rhs` to be defined to be recognized as `isMLsFun`.
// Otherwise, it would be wrapped in `runtime.safeCall`, which accesses the
// uninitialized `runtime` in the `Rendering` module.
Expand Down Expand Up @@ -211,9 +212,10 @@ declare val Infinity
declare class Promise
declare val Promise
declare object WebAssembly with
declare object Instance
declare class Instance
declare class Memory
declare object Module with
declare class Module
declare module Module with
fun
exports
imports
Expand Down
26 changes: 13 additions & 13 deletions hkmc2/shared/src/test/mlscript/deforest/basic.mls
Original file line number Diff line number Diff line change
Expand Up @@ -76,38 +76,38 @@ f(p())
//│ deforest > fields: arg$X$0$⁰.b¹
//│ deforest > <<< fusing <<<
//│ ———————————————| Lowered IR |———————————————————————————————————————————————————————————————————————
//│ let p⁰, f⁰, tmp, p_10$p⁰, f_12$f⁰, f_12$x_X⁰, f_12$x_rest⁰, f_12$arg$X$0$_Y⁰, f_12$arg$X$0$_rest⁰;
//│ let p⁰, f⁰, tmp, p_8$p⁰, f_10$f⁰, f_10$x_X⁰, f_10$x_rest⁰, f_10$arg$X$0$_Y⁰, f_10$arg$X$0$_rest⁰;
//│ @private
//│ define p_10$p⁰ as fun p_10$p¹()(eta$0$0, eta$0$1, eta$0$2) {
//│ define p_8$p⁰ as fun p_8$p¹()(eta$0$0, eta$0$1, eta$0$2) {
//│ let tmp1, Y_b, X_a;
//│ set Y_b = O⁰;
//│ set tmp1 = f_12$arg$X$0$_Y¹(Y_b);
//│ set tmp1 = f_10$arg$X$0$_Y¹(Y_b);
//│ set X_a = tmp1;
//│ return f_12$x_X¹(X_a)(eta$0$0, eta$0$1, eta$0$2)
//│ return f_10$x_X¹(X_a)(eta$0$0, eta$0$1, eta$0$2)
//│ };
//│ @private
//│ define f_12$f⁰ as fun f_12$f¹(x) {
//│ define f_10$f⁰ as fun f_10$f¹(x) {
//│ let y, arg$X$0$, arg$Y$0$;
//│ return x(y, arg$X$0$, arg$Y$0$)
//│ };
//│ @affine(1) @private
//│ define f_12$x_X⁰ as fun f_12$x_X¹(X_a)(fv_y, fv_arg$X$0$, fv_arg$Y$0$) {
//│ define f_10$x_X⁰ as fun f_10$x_X¹(X_a)(fv_y, fv_arg$X$0$, fv_arg$Y$0$) {
//│ set fv_arg$X$0$ = X_a;
//│ return fv_arg$X$0$(fv_y, fv_arg$Y$0$)
//│ };
//│ @affine(1) @private
//│ define f_12$arg$X$0$_Y⁰ as fun f_12$arg$X$0$_Y¹(Y_b)(fv_y, fv_arg$Y$0$) {
//│ define f_10$arg$X$0$_Y⁰ as fun f_10$arg$X$0$_Y¹(Y_b)(fv_y, fv_arg$Y$0$) {
//│ set fv_arg$Y$0$ = Y_b;
//│ set fv_y = fv_arg$Y$0$;
//│ return fv_y
//│ };
//│ @private
//│ define f_12$x_rest⁰ as fun f_12$x_rest¹() {
//│ define f_10$x_rest⁰ as fun f_10$x_rest¹() {
//│ return null
//│ };
//│ @private
//│ define f_12$arg$X$0$_rest⁰ as fun f_12$arg$X$0$_rest¹() {
//│ return f_12$x_rest¹()
//│ define f_10$arg$X$0$_rest⁰ as fun f_10$arg$X$0$_rest¹() {
//│ return f_10$x_rest¹()
//│ };
//│ define p⁰ as fun p¹() {
//│ let tmp1;
Expand All @@ -131,8 +131,8 @@ f(p())
//│ throw new globalThis⁰.Error⁰("match error")
//│ end
//│ };
//│ set tmp = p_10$p¹();
//│ return f_12$f¹(tmp)
//│ set tmp = p_8$p¹();
//│ return f_10$f¹(tmp)
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ = O

Expand Down Expand Up @@ -192,7 +192,7 @@ fun c(x, d) = if x is
//│ deforest > match: x⁴
//│ deforest > fields: x⁴.a¹
//│ deforest > <<< fusing <<<
//│ = fun lambda_8$lambda
//│ = fun lambda_7$lambda


:expect 2
Expand Down
16 changes: 8 additions & 8 deletions hkmc2/shared/src/test/mlscript/deforest/cyclic.mls
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ modSome(Some(17), 3)
//│ deforest > fields: x⁰.x¹
//│ deforest > <<< fusing <<<
//│ ———————————————| Lowered IR |———————————————————————————————————————————————————————————————————————
//│ let modSome⁰, tmp, modSome_12$modSome⁰, modSome_12$x_Some⁰, modSome_12$x_rest⁰, Some_x;
//│ let modSome⁰, tmp, modSome_11$modSome⁰, modSome_11$x_Some⁰, modSome_11$x_rest⁰, Some_x;
//│ @private
//│ define modSome_12$modSome⁰ as fun modSome_12$modSome¹(x, m) {
//│ define modSome_11$modSome⁰ as fun modSome_11$modSome¹(x, m) {
//│ let n, scrut, arg$Some$0$, tmp1, tmp2;
//│ return x(m, n, scrut, arg$Some$0$, tmp1, tmp2)
//│ };
//│ @affine(1) @private
//│ define modSome_12$x_Some⁰ as fun modSome_12$x_Some¹(Some_x1)(fv_m, fv_n, fv_scrut, fv_arg$Some$0$, fv_tmp, fv_tmp1) {
//│ define modSome_11$x_Some⁰ as fun modSome_11$x_Some¹(Some_x1)(fv_m, fv_n, fv_scrut, fv_arg$Some$0$, fv_tmp, fv_tmp1) {
//│ set fv_arg$Some$0$ = Some_x1;
//│ set fv_n = fv_arg$Some$0$;
//│ set fv_scrut = >=⁰(fv_n, fv_m);
Expand All @@ -100,14 +100,14 @@ modSome(Some(17), 3)
//│ let Some_x2;
//│ set fv_tmp = -⁰(fv_n, fv_m);
//│ set Some_x2 = fv_tmp;
//│ set fv_tmp1 = modSome_12$x_Some¹(Some_x2);
//│ return modSome_12$modSome¹(fv_tmp1, fv_m)
//│ set fv_tmp1 = modSome_11$x_Some¹(Some_x2);
//│ return modSome_11$modSome¹(fv_tmp1, fv_m)
//│ else
//│ return Some⁰(fv_n)
//│ end
//│ };
//│ @private
//│ define modSome_12$x_rest⁰ as fun modSome_12$x_rest¹() {
//│ define modSome_11$x_rest⁰ as fun modSome_11$x_rest¹() {
//│ return null
//│ };
//│ define modSome⁰ as fun modSome¹(x, m) {
Expand Down Expand Up @@ -135,8 +135,8 @@ modSome(Some(17), 3)
//│ end
//│ };
//│ set Some_x = 17;
//│ set tmp = modSome_12$x_Some¹(Some_x);
//│ return modSome_12$modSome¹(tmp, 3)
//│ set tmp = modSome_11$x_Some¹(Some_x);
//│ return modSome_11$modSome¹(tmp, 3)
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ = Some(2)

Expand Down
Loading
Loading