Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions hkmc2/js/src/test/scala/hkmc2/CompilerTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import scala.scalajs.js.annotation._
import scala.scalajs.js.Dynamic.global

class CompilerTest extends AnyFunSuite:
private def hasErrors(diagnostics: js.Array[js.Dynamic]): Bool =
diagnostics.exists: perFile =>
val fileDiagnostics = perFile.diagnostics.asInstanceOf[js.Array[js.Dynamic]]
fileDiagnostics.exists(_.kind is "error")

private def loadStandardLibrary(): Map[String, String] =
val projectRoot = node.process.cwd()
val compilePath = node.path.join(projectRoot, "hkmc2", "shared", "src", "test", "mlscript-compile")
Expand Down Expand Up @@ -57,10 +62,7 @@ class CompilerTest extends AnyFunSuite:

val diagnostics = compiler.compile(inputPath)

val hasErrors = diagnostics.exists: perFile =>
val fileDiagnostics = perFile.diagnostics.asInstanceOf[scala.scalajs.js.Array[scala.scalajs.js.Dynamic]]
fileDiagnostics.exists(_.kind is "error")
assert(!hasErrors, "Compilation should succeed without errors")
assert(!hasErrors(diagnostics), "Compilation should succeed without errors")

val outputExists = fs.exists(Path(outputPath))
assert(outputExists, "Output JavaScript file should be generated")
Expand Down Expand Up @@ -92,6 +94,32 @@ class CompilerTest extends AnyFunSuite:

assert(fs.exists(Path("/Foo.mjs")), "First output should exist")
assert(fs.exists(Path("/Bar.mjs")), "Second output should exist")

test("compiler emits wasm artifacts for a wasm-targeted program"):
val (fs, compiler) = createCompiler()

fs.write("/simpleWasm.mls",
"""|#config(target: CompilationTarget.Wasm)
|
|40 + 2
|""".stripMargin)

val diagnostics = compiler.compile("/simpleWasm.mls")

assert(!hasErrors(diagnostics), "Compilation should succeed without errors")

assert(fs.exists(Path("/simpleWasm.mjs")), "Glue JavaScript file should be generated")
assert(fs.exists(Path("/simpleWasm.wat")), "WAT file should be generated")
assert(fs.exists(Path("/RuntimeWASM.mjs")), "Shared Wasm runtime helper should be generated")
assert(fs.exists(Path("/RuntimeWASM.wat")), "Shared Wasm intrinsic WAT should be generated")

val glue = fs.read("/simpleWasm.mjs")
assert(glue.contains("__mlx_compileWatFromUrl"), "Glue code should load module WAT from file")
assert(glue.contains("export const __mlx_wasm"), "Glue code should expose the internal wasm loader")
assert(glue.contains("export default"), "Glue code should export the module result by default")

val wat = fs.read("/simpleWasm.wat")
assert(wat.contains("(module"), "Generated WAT should contain a module")

test("compiler can report errors"):
val (fs, compiler) = createCompiler()
Expand All @@ -102,8 +130,4 @@ class CompilerTest extends AnyFunSuite:

assert(diagnostics.length is 1, "Should report diagnostics for only one file")

val hasErrors = diagnostics.exists: perFile =>
val fileDiagnostics = perFile.diagnostics.asInstanceOf[scala.scalajs.js.Array[scala.scalajs.js.Dynamic]]
fileDiagnostics.exists(_.kind is "error")

assert(hasErrors, "Compilation should report errors")
assert(hasErrors(diagnostics), "Compilation should report errors")
29 changes: 22 additions & 7 deletions hkmc2/shared/src/main/scala/hkmc2/CompilerCtx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import mlscript.utils.*, shorthands.*
import hkmc2.utils.*
import hkmc2.Message.MessageContext
import hkmc2.io
import utils.TraceLogger
import utils.TL

import semantics.*
import Elaborator.*
Expand All @@ -34,7 +34,7 @@ class CompilerCtx(
CompilerCtx(S(newFile, this), beingCompiled + newFile, fs, cache)

def getElaboratedBlock
(file: io.Path, prelude: Ctx)
(file: io.Path, prelude: Ctx, full: Bool)
(using TL, Raise, Config)
: Artifact =

Expand Down Expand Up @@ -62,18 +62,28 @@ class CompilerCtx(
given CompilerCtx = this
ParserSetup(file, dbgParsing = false)
val resBlk = parse.resultBlk
given Elaborator.Ctx = prelude.copy(mode = Mode.Light).nestLocal("prelude")
val mode = if full then Mode.Full else Mode.Light
given Elaborator.Ctx = prelude.copy(mode = mode).nestLocal("prelude")
val elab =
given CompilerCtx = derive(parse.origin.fileName)
Elaborator(tl, file.up, prelude)
val elabbed = elab.importFrom(resBlk)
Artifact(resBlk, elabbed._1, lastMod)
val (blk, _) = elab.importFrom(resBlk)
// Resolve once here so cached terms are not mutated again by later passes
// (diamond imports would otherwise re-run Resolver and trip expand()).
Resolver(summon[TL])(using summon[Raise], summon[State], summon[Ctx], summon[Config])
.traverseBlock(blk)(using Resolver.ICtx.empty)
Artifact(resBlk, blk, lastMod, full)

cache.upsert(file):
case N => mk
case cur @ S(art) =>
if art.lastChangedTimestamp < lastMod then mk
if art.lastChangedTimestamp < lastMod || !art.mode && full then mk
else art

def getCachedElaboratedBlock(file: io.Path, full: Bool): Opt[Artifact] =
val lastMod = fs.getLastChangedTimestamp(file)
cache.elabCache.get(file).filter: art =>
art.lastChangedTimestamp >= lastMod && (art.mode || !full)


object CompilerCtx:
Expand All @@ -88,7 +98,12 @@ end CompilerCtx

object CompilerCache:

class Artifact(val tree: syntax.Tree.Block, val term: semantics.Term.Blk, val lastChangedTimestamp: Long)
class Artifact(
val tree: syntax.Tree.Block,
val term: semantics.Term.Blk,
val lastChangedTimestamp: Long,
val mode: Bool,
)

end CompilerCache

Expand Down
15 changes: 15 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,18 @@ object ConfigParser:
source = Diagnostic.Source.Compilation))
N

private def parseCompilationTarget(tree: Tree)(using Raise): Opt[CompilationTarget] =
tree.deparenthesized match
case Sel(Ident("CompilationTarget"), Ident("JS")) | Ident("JS") =>
S(CompilationTarget.JS)
case Sel(Ident("CompilationTarget"), Ident("Wasm")) | Ident("Wasm") =>
S(CompilationTarget.Wasm)
case _ =>
raise(ErrorReport(
msg"Expected CompilationTarget.JS or CompilationTarget.Wasm" -> tree.toLoc :: Nil,
source = Diagnostic.Source.Compilation))
N

/** Parse the `None`/`Some(...)` syntax for optional config fields.
* Also accepts unwrapped values as a convenience (treated as `Some(value)`). */
private def parseOpt[A](tree: Tree)(parseInner: Tree => Opt[A])(using Raise): Opt[Opt[A]] = tree match
Expand Down Expand Up @@ -363,6 +375,9 @@ object ConfigParser:
case "commentGeneratedCode" => parseBool(value) match
case S(v) => _.copy(commentGeneratedCode = v)
case N => identity
case "target" => parseCompilationTarget(value) match
case S(v) => _.copy(target = v)
case N => identity
case "effectHandlers" =>
cfg =>
parseOpt(value)(v => parseEffectHandlers(v, cfg.effectHandlers)) match
Expand Down
Loading