Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ lazy val hkmc2 = crossProject(JSPlatform, JVMPlatform).in(file("hkmc2"))
watchSources += WatchSource(
baseDirectory.value.getParentFile()/"shared"/"src"/"test"/"diff", "*.mls", NothingFilter),

// TODO remove when codebase becomes production-ready
scalacOptions -= "-Wconf:any:error",

// scalacOptions ++= Seq("-indent", "-rewrite"),
scalacOptions ++= Seq("-new-syntax", "-rewrite"),
// scalacOptions ++= Seq("-language:experimental.modularity"), // https://docs.scala-lang.org/scala3/reference/experimental/modularity.html
Expand Down
5 changes: 5 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ sealed abstract class Block extends Product:
case Define(defn, rest) =>
s"""|Define(${defn.showDbg})
|${rest.showDbg}""".stripMargin
case Throw(exc) => s"Throw(${exc.showDbg})"
case Scoped(syms, body) =>
s"""|Scoped(${syms.toList.map(_.showDbg).sorted.mkString(", ")}) {
|${body.showDbg}
|}""".stripMargin

lazy val isAbortive: Bool = this match
case _: End => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise):
(funToSccRepMap(lastRefedSymbol), funToSccRepMap(refSym)) match
case (Some(a), Some(b)) if a is b => instId
case _ => instId :+ refId
case _ => lastWords(s"newRefId: impossible InstantiationId shape $instId")
end newRefId

p match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class EtaExpansionSolver(val constraintSolver: FlowConstraintSolver):
case (fSym: TermSymbol, whichPl: Int) =>
val fnDefn = constraintSolver.collector.preAnalyzer.res.funSymToFunDefn(fSym)
fnDefn.affineInfo.contains(whichPl)
case (sym, _) => lastWords(s"prodFunIsAffine: expected TermSymbol funId, got $sym")
end prodFunIsAffine

res match
Expand Down
11 changes: 5 additions & 6 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter):
setupSelection(prefix, proj, N)(k)
case Resolved(sp @ SelProj(prefix, _, proj), sym) =>
setupSelection(prefix, proj, S(sym))(k)
case Resolved(inner, sym) => TODO(s"lowering for Resolved($inner)")
case Region(reg, body) =>
loweringCtx.collectScopedSym(reg)
Assign(reg, Instantiate(mut = true, Select(State.globalThisSymbol.asThis, Tree.Ident("Region"))(N), Nil :: Nil),
Expand Down Expand Up @@ -1041,7 +1042,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter):
msg"Cannot compile ${t.describe} term that was not elaborated (maybe elaboration was one in 'lightweight' mode?)" ->
t.toLoc :: Nil,
source = Diagnostic.Source.Compilation)
case _: CompType | _: Neg | _: Term.FunTy | _: Term.Forall | _: Term.WildcardTy | _: Term.Unquoted | _: LeadingDotSel
case _: CompType | _: Neg | _: Term.FunTy | _: Term.Forall | _: Term.WildcardTy | _: Term.Unquoted | _: LeadingDotSel | _: Term.Constrained | _: Term.Annotated
=> fail:
ErrorReport(
msg"Unexpected term form in expression position (${t.describe})" ->
Expand Down Expand Up @@ -1089,8 +1090,6 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter):
.chain(b => quote(term)(r2 => Assign(l2, r2, b)))
.chain(b => quoteSplit(tail, splitTmps)(r3 => Assign(l3, r3, b)))
.rest(setupTerm("Let", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef))(k))
case Split.Let(sym, _, _) =>
lastWords(s"tried to quote split let with non-variable symbol ${sym.nme}")
case Split.Else(default) => quote(default): r =>
val l = loweringCtx.registerTempSymbol(N)
Assign(l, r, setupTerm("Else", l.asSimpleRef :: Nil)(k))
Expand All @@ -1101,11 +1100,11 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter):
setupSymbol(tmp): r1 =>
val l1, l2, l3 = loweringCtx.registerTempSymbol(N)
blockBuilder.assign(l1, r1)
.chain(b => Assign(tmp, Value.Ref(l1), b))
.chain(b => Assign(tmp, l1.asSimpleRef, b))
.chain(b => quoteSplit(sym.body, splitTmps2)(r2 => Assign(l2, r2, b)))
.chain(b => quoteSplit(tail, splitTmps2)(r3 => Assign(l3, r3, b)))
.rest(setupTerm("LetSplit", (l1 :: l2 :: l3 :: Nil).map(s => Value.Ref(s)))(k))
case Split.UseSplit(sym) => setupTerm("UseSplit", Value.SimpleRef(splitTmps(sym)) :: Nil)(k)
.rest(setupTerm("LetSplit", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef))(k))
case Split.UseSplit(sym) => setupTerm("UseSplit", splitTmps(sym).asSimpleRef :: Nil)(k)

lazy val setupFilename: Path =
val state = summon[State]
Expand Down
1 change: 0 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config):
case End(msg) if msg.nonEmpty && config.commentGeneratedCode => doc"end /* ${msg} */"
case End(_) => doc"end"
case Unreachable(msg) => doc"unreachable /* ${msg} */"
case _ => TODO(blk)

def printFlags(defn: Defn)(using Scope): Document =
// val overrides = defn match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class DeforestFusionSolver(val constraintSolver: FlowConstraintSolver)(using val
dest match
case FinalDestMatch(dtor, sels) =>
tl.log(s"\tmatch: ${pp(dtor)}")
for s <- sels.toSeq.sortBy(_.exprId) do tl.log(s"\tfields: ${pp(s)}")
for s <- sels.toSeq.sortBy(_.exprId.uid) do tl.log(s"\tfields: ${pp(s)}")
case FinalDestSel(dtors, field) =>
tl.log(s"\tselect: ${pp(field)}")
tl.log("<<< fusing <<<")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise):
finalDest match
case FinalDestSel(dtors, field) =>
// create poly fun syms
val instIds = (dtors + ctor).toList.sortBy(_.exprId).map(_.instId)
val instIds = (dtors + ctor).toList.sortBy(_.exprId.uid).map(_.instId)
for
ctorInstId <- instIds
case path@(pathTo :+ refedFun) <- ctorInstId.inits
Expand All @@ -142,7 +142,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise):

// create branch sel syms
val fieldSym = MutMap.empty[SelField, VarSymbol]
for sel <- sels.toList.sortBy(_._1) do
for sel <- sels.toList.sortBy(_._1.uid) do
branchSelSyms.getOrElseUpdate(
sel,
locally:
Expand Down Expand Up @@ -357,7 +357,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise):

def mkCall(target: (BlockMemberSymbol, TermSymbol), args: Ls[ValueSymbol]): Call =
Call(
Value.Ref(target._1, S(target._2)),
Value.MemberRef(target._1, target._2),
args.map(a => Arg(N, a.asPath)) ne_:: Nil
)(true, false, false)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ object FlowAnalysis:
val AccumulatorSym = "flow-analysis/accumulator"

class State:
val resultToResultId = new java.util.IdentityHashMap[Result, Uid[Result]].asScala
val resultIdToResult = mutable.Map.empty[Uid[Result], Result]
val resultToResultId = new java.util.IdentityHashMap[Result, ResultId].asScala
val resultIdToResult = mutable.Map.empty[ResultId, Result]
val stratVarIdToState = mutable.Map.empty[StratVarId, StratVarState]
object ResultUidState extends Uid.Result.State

Expand All @@ -46,7 +46,7 @@ object FlowAnalysis:
extension (r: Result)
def uid = resultToResultId.get(r) match
case None =>
val id = ResultUidState.nextUid
val id = ResultId(ResultUidState.nextUid)
resultIdToResult(id) = r
resultToResultId(r) = id
id
Expand Down Expand Up @@ -81,11 +81,13 @@ object FlowAnalysis:
end FlowAnalysis


type ResultId = Uid[Result]
class ResultId(val uid: Uid[Result]):
override def toString: String = uid.toString

type InstantiationId = Ls[ResultId]
type CtorCls = ClassLikeSymbol | Int
type SelField = TermSymbol | Int
type FunId = (funSym: Symbol, whichParamList: Int) | ResultId
type FunId = (funSym: TermSymbol, whichParamList: Int) | ResultId
type OriginId = ResultId | FunId


Expand Down Expand Up @@ -255,7 +257,7 @@ class Dtor(

case class ConcreteId[A <: OriginId](exprId: A, instId: InstantiationId):
def pp(using FlowAnalysis.State): Str = exprId match
case (sym: Symbol, idx: Int) => s"${sym.nme}#$idx"
case (sym: TermSymbol, idx: Int) => s"${sym.nme}#$idx"
case r: ResultId => s"${r.getResult}"


Expand Down Expand Up @@ -809,10 +811,11 @@ class FlowConstraintsCollector(
)(using cc: ConstraintsCollector): ProdStrat =
def paramListFunId(whichParamList: Int): FunId =
funLamId match
case (sym: Symbol, -1) => (sym, whichParamList)
case (sym: TermSymbol, -1) => (sym, whichParamList)
case lambdaExprId: ResultId =>
assert(whichParamList == 0)
lambdaExprId
case other => lastWords(s"unexpected funLamId shape: $other")
val capturedSyms = funLamId match
case (sym: TermSymbol, _) => preAnalyzer.res.capturedVars(sym)
case lamExprId: ResultId => preAnalyzer.res.capturedVars(lamExprId)
Expand Down Expand Up @@ -981,7 +984,6 @@ class FlowConstraintsCollector(
rest.foreach: nextArgs =>
nextArgs.foreach(a => cc.constrain(processResult(a.value), UnknownCons))
UnknownProd
case Nil => handleCallLike(c.uid, fun, Nil)
case i@Instantiate(_, cls, argss) => handleCallLike(i.uid, cls, argss.flatten)
case lam@Lambda(ps, body) =>
mkFunProdStrat("lam_res", ps :: Nil, body, lam.uid)
Expand Down
14 changes: 13 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
case FunTy(Term.Tup(params), ret, eff) =>
PolyFunType(params.map {
case Fld(_, p, _) => typeAndSubstType(p, !pol)
case spd: Spd => lastWords(s"unexpected spread in function type parameters: $spd")
}, typeAndSubstType(ret, pol), eff.map(e => typeAndSubstType(e, pol) match {
case t: Type => t
case _ => error(msg"Effect cannot be polymorphic." -> ty.toLoc :: Nil)
Expand Down Expand Up @@ -180,7 +181,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
case _ =>
ty.symbol.flatMap(_.asTpe) match
case S(cls: (ClassSymbol | TypeAliasSymbol)) => typeAndSubstType(Term.TyApp(ty, Nil)(N), pol)
case N => error(msg"Invalid type" -> ty.toLoc :: Nil, S(ty)) // TODO
case S(_) | N => error(msg"Invalid type" -> ty.toLoc :: Nil, S(ty)) // TODO

private def genPolyType(tvs: Ls[QuantVar], outer: InfVar, body: => GeneralType)(using ctx: InvalCtx, cctx: CCtx) =
val bds = tvs.map:
Expand Down Expand Up @@ -248,6 +249,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
case (res, p: Fld) =>
val (ty, ctx, eff) = typeCode(p.term)
(ty :: res._1, res._2 | ctx, res._3 | eff)
case (_, spd: Spd) => TODO(s"spread arguments in quoted code: $spd")
val resTy = freshVar(new TempSymbol(S(app), "app"))
constrain(lhsTy, FunType(rhsTy.reverse, resTy, Bot)) // TODO: right
(resTy, lhsCtx | rhsCtx, lhsEff | rhsEff)
Expand Down Expand Up @@ -312,6 +314,9 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
case L(N) => rec(alts, L(S(sym)))
case L(S(other)) if adtParent.get(other.uid).exists(p => p.uid == adtParent(sym.uid).uid) =>
rec(alts, L(S(sym)))
case L(S(_)) =>
error(msg"Matching patterns from different ADTs in one match is not supported." -> split.toLoc :: Nil)
false
case R(_) =>
error(msg"Mixing ADT pattern matching and general matching is not supported yet." -> split.toLoc :: Nil)
false
Expand Down Expand Up @@ -373,10 +378,13 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
params.iterator.zip(paramList).foreach:
case (p, Param(_, _, S(ty), _)) =>
nestCtx += p._1 -> typeAndSubstType(ty, true)(using map.toMap)
case (_, p) =>
error(msg"Invalid ADT parameter." -> p.toLoc :: Nil)
val (consTy, consEff) = typeAllSplits(cons, sign)(using nestCtx)
val (altsTy, altsEff, altCases, fallback) = typeADTMatch(alts, sign)
val allEff = scrutineeEff | (consEff | altsEff)
(sign.getOrElse(tryMkMono(consTy, cons) | tryMkMono(altsTy, alts)), allEff, sym :: altCases, fallback)
case _ => lastWords(s"unexpected non-ClassLike pattern in ADT match: $pattern")
case Split.Let(name, term, tail) =>
val nestCtx = ctx.nest
given InvalCtx = nestCtx
Expand Down Expand Up @@ -428,6 +436,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
case _: Tree.DecLit => InvalCtx.numTy
case _: Tree.StrLit => InvalCtx.strTy
case _: Tree.UnitLit => InvalCtx.unitTy
case (_: FlatPattern.Tuple) | (_: FlatPattern.Record) => TODO(s"tuple/record patterns in invalml split typing: $pattern")
constrain(tryMkMono(scrutineeTy, scrutinee), patTy)
val (consTy, consEff) = typeSplit(cons, sign)(using nestCtx1)
val (altsTy, altsEff) = typeSplit(alts, sign)(using nestCtx2)
Expand Down Expand Up @@ -523,6 +532,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
case (f: Fld, t) =>
val (ty, ef) = ascribe(f.term, t)
resEff |= ef
case (spd: Spd, _) => TODO(s"spread arguments: $spd")
(ret, resEff)
case (FunType(params, ret, eff), lhsEff) => app((PolyFunType(params, ret, eff), lhsEff), rhs, t)
case (ty: PolyType, eff) => app((instantiate(ty), eff), rhs, t)
Expand All @@ -531,6 +541,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
case f: Fld =>
val (ty, eff) = typeCheck(f.term)
Left(ty) :: Right(eff) :: Nil
case spd: Spd => TODO(s"spread arguments: $spd")
.partitionMap(x => x)
val effVar = freshVar(new TempSymbol(S(t), "eff"))
val retVar = freshVar(new TempSymbol(S(t), "app"))
Expand Down Expand Up @@ -586,6 +597,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
case Wildcard(in: InfVar, _) => in :: Nil
case Wildcard(_, out: InfVar) => out :: Nil
case v: InfVar => v :: Nil
case other => lastWords(s"unexpected non-InfVar type argument in ADT ctor: $other")
}, N, PolyFunType(clsDef.params.params.map {
case Param(_, _, S(ty), _) => typeAndSubstType(ty, true)(using map.toMap)
case p =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package hkmc2.invalml

import scala.collection.mutable.{Map => MutMap, Set => MutSet, LinkedHashMap, LinkedHashSet}
import scala.collection.immutable.{SortedMap, SortedSet}
import scala.util.boundary, boundary.break
import scala.util.chaining._

import mlscript.utils.*, shorthands.*
Expand Down Expand Up @@ -73,14 +74,14 @@ class TypeSimplifier(tl: TraceLogger):
case N => tv
}

override def apply(pol: Bool)(ty: GeneralType): Unit =
override def apply(pol: Bool)(ty: GeneralType): Unit = boundary:
trace(s"Analyse[${printPol(pol)}] ${ty.showDbg} [${curPath.reverseIterator.mkString(" ~> ")}]"):
ty match
case ty if ty.lvl <= lvl =>
log(s"Level is < $lvl")
case tv: IV if { occsNum(tv) = occsNum.getOrElse(tv, 0) + 1; false } =>
case tv: IV =>
if varSubst.contains(tv) then return log(s"Already subst'd") // * If the IV was set to be substituted, it means it's been found recursive and we don't need to traverse it again
if varSubst.contains(tv) then break(log(s"Already subst'd")) // * If the IV was set to be substituted, it means it's been found recursive and we don't need to traverse it again
var continue = true
// if (!traversedTVs.contains(tv)) {
if curPath.exists(_ is tv) then // TODO opt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ trait BlockImpl(using Elaborator.State):
lazy val (headId, headPs) = td.baseHead match
case id: Ident => (id, Nil)
case App(id: Ident, TyTup(ps)) => (id, ps)
case _ => TODO(s"unexpected ADT head shape: ${td.baseHead}")
// Temporarily use `data` annotation to distinguish the following ctors:
// - Ctor(...)
// - Ctor[...](...) extends ADT[...]
Expand Down
Loading
Loading