diff --git a/hkmc2/js/src/test/scala/hkmc2/CompilerTest.scala b/hkmc2/js/src/test/scala/hkmc2/CompilerTest.scala index e8fc6a727f..44a5cb18bd 100644 --- a/hkmc2/js/src/test/scala/hkmc2/CompilerTest.scala +++ b/hkmc2/js/src/test/scala/hkmc2/CompilerTest.scala @@ -8,6 +8,11 @@ import scala.scalajs.js.annotation._ import scala.scalajs.js.Dynamic.global class CompilerTest extends AnyFunSuite: + private def hasErrors(diagnostics: js.Array[js.Dynamic]): Bool = + diagnostics.exists: perFile => + val fileDiagnostics = perFile.diagnostics.asInstanceOf[js.Array[js.Dynamic]] + fileDiagnostics.exists(_.kind is "error") + private def loadStandardLibrary(): Map[String, String] = val projectRoot = node.process.cwd() val compilePath = node.path.join(projectRoot, "hkmc2", "shared", "src", "test", "mlscript-compile") @@ -57,10 +62,7 @@ class CompilerTest extends AnyFunSuite: val diagnostics = compiler.compile(inputPath) - val hasErrors = diagnostics.exists: perFile => - val fileDiagnostics = perFile.diagnostics.asInstanceOf[scala.scalajs.js.Array[scala.scalajs.js.Dynamic]] - fileDiagnostics.exists(_.kind is "error") - assert(!hasErrors, "Compilation should succeed without errors") + assert(!hasErrors(diagnostics), "Compilation should succeed without errors") val outputExists = fs.exists(Path(outputPath)) assert(outputExists, "Output JavaScript file should be generated") @@ -92,6 +94,32 @@ class CompilerTest extends AnyFunSuite: assert(fs.exists(Path("/Foo.mjs")), "First output should exist") assert(fs.exists(Path("/Bar.mjs")), "Second output should exist") + + test("compiler emits wasm artifacts for a wasm-targeted program"): + val (fs, compiler) = createCompiler() + + fs.write("/simpleWasm.mls", + """|#config(target: CompilationTarget.Wasm) + | + |40 + 2 + |""".stripMargin) + + val diagnostics = compiler.compile("/simpleWasm.mls") + + assert(!hasErrors(diagnostics), "Compilation should succeed without errors") + + assert(fs.exists(Path("/simpleWasm.mjs")), "Glue JavaScript file should be generated") + assert(fs.exists(Path("/simpleWasm.wat")), "WAT file should be generated") + assert(fs.exists(Path("/RuntimeWASM.mjs")), "Shared Wasm runtime helper should be generated") + assert(fs.exists(Path("/RuntimeWASM.wat")), "Shared Wasm intrinsic WAT should be generated") + + val glue = fs.read("/simpleWasm.mjs") + assert(glue.contains("__mlx_compileWatFromUrl"), "Glue code should load module WAT from file") + assert(glue.contains("export const __mlx_wasm"), "Glue code should expose the internal wasm loader") + assert(glue.contains("export default"), "Glue code should export the module result by default") + + val wat = fs.read("/simpleWasm.wat") + assert(wat.contains("(module"), "Generated WAT should contain a module") test("compiler can report errors"): val (fs, compiler) = createCompiler() @@ -102,8 +130,4 @@ class CompilerTest extends AnyFunSuite: assert(diagnostics.length is 1, "Should report diagnostics for only one file") - val hasErrors = diagnostics.exists: perFile => - val fileDiagnostics = perFile.diagnostics.asInstanceOf[scala.scalajs.js.Array[scala.scalajs.js.Dynamic]] - fileDiagnostics.exists(_.kind is "error") - - assert(hasErrors, "Compilation should report errors") + assert(hasErrors(diagnostics), "Compilation should report errors") diff --git a/hkmc2/shared/src/main/scala/hkmc2/CompilerCtx.scala b/hkmc2/shared/src/main/scala/hkmc2/CompilerCtx.scala index 968a6efd70..54da97245c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/CompilerCtx.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/CompilerCtx.scala @@ -8,7 +8,7 @@ import mlscript.utils.*, shorthands.* import hkmc2.utils.* import hkmc2.Message.MessageContext import hkmc2.io -import utils.TraceLogger +import utils.TL import semantics.* import Elaborator.* @@ -34,7 +34,7 @@ class CompilerCtx( CompilerCtx(S(newFile, this), beingCompiled + newFile, fs, cache) def getElaboratedBlock - (file: io.Path, prelude: Ctx) + (file: io.Path, prelude: Ctx, full: Bool) (using TL, Raise, Config) : Artifact = @@ -62,18 +62,28 @@ class CompilerCtx( given CompilerCtx = this ParserSetup(file, dbgParsing = false) val resBlk = parse.resultBlk - given Elaborator.Ctx = prelude.copy(mode = Mode.Light).nestLocal("prelude") + val mode = if full then Mode.Full else Mode.Light + given Elaborator.Ctx = prelude.copy(mode = mode).nestLocal("prelude") val elab = given CompilerCtx = derive(parse.origin.fileName) Elaborator(tl, file.up, prelude) - val elabbed = elab.importFrom(resBlk) - Artifact(resBlk, elabbed._1, lastMod) + val (blk, _) = elab.importFrom(resBlk) + // Resolve once here so cached terms are not mutated again by later passes + // (diamond imports would otherwise re-run Resolver and trip expand()). + Resolver(summon[TL])(using summon[Raise], summon[State], summon[Ctx], summon[Config]) + .traverseBlock(blk)(using Resolver.ICtx.empty) + Artifact(resBlk, blk, lastMod, full) cache.upsert(file): case N => mk case cur @ S(art) => - if art.lastChangedTimestamp < lastMod then mk + if art.lastChangedTimestamp < lastMod || !art.mode && full then mk else art + + def getCachedElaboratedBlock(file: io.Path, full: Bool): Opt[Artifact] = + val lastMod = fs.getLastChangedTimestamp(file) + cache.elabCache.get(file).filter: art => + art.lastChangedTimestamp >= lastMod && (art.mode || !full) object CompilerCtx: @@ -88,7 +98,12 @@ end CompilerCtx object CompilerCache: - class Artifact(val tree: syntax.Tree.Block, val term: semantics.Term.Blk, val lastChangedTimestamp: Long) + class Artifact( + val tree: syntax.Tree.Block, + val term: semantics.Term.Blk, + val lastChangedTimestamp: Long, + val mode: Bool, + ) end CompilerCache diff --git a/hkmc2/shared/src/main/scala/hkmc2/Config.scala b/hkmc2/shared/src/main/scala/hkmc2/Config.scala index 4d8e98f2dc..05beff1559 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/Config.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/Config.scala @@ -212,6 +212,18 @@ object ConfigParser: source = Diagnostic.Source.Compilation)) N + private def parseCompilationTarget(tree: Tree)(using Raise): Opt[CompilationTarget] = + tree.deparenthesized match + case Sel(Ident("CompilationTarget"), Ident("JS")) | Ident("JS") => + S(CompilationTarget.JS) + case Sel(Ident("CompilationTarget"), Ident("Wasm")) | Ident("Wasm") => + S(CompilationTarget.Wasm) + case _ => + raise(ErrorReport( + msg"Expected CompilationTarget.JS or CompilationTarget.Wasm" -> tree.toLoc :: Nil, + source = Diagnostic.Source.Compilation)) + N + /** Parse the `None`/`Some(...)` syntax for optional config fields. * Also accepts unwrapped values as a convenience (treated as `Some(value)`). */ private def parseOpt[A](tree: Tree)(parseInner: Tree => Opt[A])(using Raise): Opt[Opt[A]] = tree match @@ -363,6 +375,9 @@ object ConfigParser: case "commentGeneratedCode" => parseBool(value) match case S(v) => _.copy(commentGeneratedCode = v) case N => identity + case "target" => parseCompilationTarget(value) match + case S(v) => _.copy(target = v) + case N => identity case "effectHandlers" => cfg => parseOpt(value)(v => parseEffectHandlers(v, cfg.effectHandlers)) match diff --git a/hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala b/hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala index 773b3289b0..04d1bbd345 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala @@ -6,10 +6,15 @@ import mlscript.utils.*, shorthands.* import hkmc2.io import utils.* +import hkmc2.Message.MessageContext import hkmc2.semantics.* import hkmc2.syntax.Keyword.`override` import semantics.Elaborator.{Ctx, State} +import codegen.wasm.text.{ + CompiledWasmModule, SessionFunc, SessionGlobal, SessionSingleton, WatBuilder, +} + class ParserSetup(file: io.Path, dbgParsing: Bool)(using state: Elaborator.State, raise: Raise, cctx: CompilerCtx): @@ -41,24 +46,288 @@ object MLsCompiler: def termFile: io.Path /** - * The compiler that compiles MLscript code into JavaScript modules. + * The compiler that compiles MLscript code into JavaScript or WebAssembly modules. * * @param paths required paths needed by the compiler * @param mkRaise generates a separate `Raise` function for each file. * @param config the compiler's configuration object - * @param fs the file system interface + * @param cctx the compilation context (including the file system interface) */ class MLsCompiler (paths: MLsCompiler.Paths, mkRaise: io.Path => Raise) (using cctx: CompilerCtx, config: Config): import paths.* + given DebugPrinter = new DebugPrinter + val etl = new TraceLogger{override def doTrace: Bool = false} + val ltl = new TraceLogger{override def doTrace: Bool = false} + val rtl = new TraceLogger{override def doTrace: Bool = false} var dbgParsing = false var dbgElab = false + /** Symbols a module wants preserved through optimization passes and surfaced to + * downstream importers: the exported module itself, its members, and their type symbols. */ + private def preservedSymbolsFor(exportedSymbol: Opt[BlockMemberSymbol]): Set[codegen.Local] = + val members = exportedSymbol.iterator.flatMap: sym => + sym.modOrObjTree.iterator.flatMap(_.definedSymbols.valuesIterator) + .toSet + (exportedSymbol.iterator ++ members.iterator ++ members.iterator.flatMap(_.tsym.toSeq)) + .map(sym => sym: codegen.Local).toSet + + /** JS-specific emit: build module source and write the `.mjs` file. */ + private def emitJs( + file: io.Path, + wd: io.Path, + program: codegen.Program, + exportedSymbol: Opt[BlockMemberSymbol], + )(using Raise, State, Ctx, Config): Unit = + val jsb = ltl.givenIn: + codegen.js.JSBuilder() + val baseScp: utils.Scope = + utils.Scope.empty(utils.Scope.Cfg.default) + // * This line serves for `import.meta.url`, which retrieves directory and file names of mjs files. + // * Having `module id"import" with ...` in `prelude.mls` will generate `globalThis.import` that is undefined. + baseScp.addToBindings(Elaborator.State.importSymbol, "import", shadow = false) + val nestedScp = baseScp.nest + val je = nestedScp.givenIn: + jsb.program(program, exportedSymbol, wd) + cctx.fs.write(file.up / io.RelPath(s"${file.baseName}.mjs"), je.stripBreaks.mkString(100)) + + /** Resolves a Wasm module's `.mjs` import path back to its `.mls` source, if present on disk. */ + private def resolveImportSource(importPath: Str, moduleDir: io.Path): Opt[io.Path] = + val importedPath = + if importPath.startsWith("/") + then io.Path(importPath) + else moduleDir / io.RelPath(importPath) + if importedPath.ext === "mjs" then + val resolved = importedPath.up / io.RelPath(s"${importedPath.baseName}.mls") + resolved.optionIf(cctx.fs.exists(resolved)) + else N + + /** JavaScript glue that loads the module's `.wat`, links its Wasm dependencies, + * and re-exports the program result. */ + private def wasmGlue( + wd: io.Path, + moduleBaseName: Str, + compiled: CompiledWasmModule, + exportedSymbol: Opt[BlockMemberSymbol], + dependencies: Seq[(Str, Str)], + ): Str = + def relativeImportPath(path: Str): Str = + if path.startsWith("/") + then "./" + io.Path(path).relativeTo(wd).map(_.toString).getOrElse(path) + else path + val dependencyImports = dependencies.zipWithIndex.map: + case ((importPath, _), idx) => + s"""import { __mlx_wasm as __mlx_dep_$idx } from "${relativeImportPath(importPath)}";""" + val dependencyImportStmts = + if dependencyImports.isEmpty then "" + else dependencyImports.mkString("", "\n", "\n\n") + val dependencyBindings = dependencies.zipWithIndex.map: + case ((_, moduleName), idx) => + s""" importObject[${moduleName.escaped}] = (await __mlx_dep_$idx()).instance.exports""" + val dependencyBindingStmts = + if dependencyBindings.isEmpty then "" + else dependencyBindings.mkString("", "\n", "\n") + val exportedValueExpr = exportedSymbol.flatMap: sym => + compiled.sessionExports.collectFirst: + case f: SessionFunc if f.sym === sym => + if sym.modOrObjTree.nonEmpty then s"""instance.exports[${f.exportName.escaped}]()""" + else s"""instance.exports[${f.exportName.escaped}]""" + case g: SessionGlobal if g.sym === sym => s"""instance.exports[${g.exportName.escaped}].value""" + case s: SessionSingleton if s.blockSym === sym => s"""instance.exports[${s.exportName.escaped}].value""" + .getOrElse(s"""instance.exports[${compiled.entryName.escaped}]()""") + s"""|import { __mlx_compileWatFromUrl, __mlx_buildSystem } from "./RuntimeWASM.mjs" + |$dependencyImportStmts + | + |const __mlx_wat_url = new URL("./${moduleBaseName}.wat", import.meta.url) + | + |async function __mlx_importObject() { + | const importObject = { + | system: await __mlx_buildSystem(${compiled.systemMemMinPages}) + | } + |$dependencyBindingStmts + | return importObject + |} + | + |const __mlx_wasmPromise = __mlx_importObject() + | .then(importObject => __mlx_compileWatFromUrl(__mlx_wat_url, importObject)) + | + |export const __mlx_wasm = () => __mlx_wasmPromise + | + |const { instance } = await __mlx_wasm() + | + |export default $exportedValueExpr + |""".stripMargin + + /** Source of the shared `RuntimeWASM.mjs` helper module. */ + private val wasmRuntimeMjsSource: Str = + """|import binaryen from "binaryen" + | + |async function __mlx_loadWatText(url) { + | if (url.protocol === "file:") { + | const { readFile } = await import("node:fs/promises") + | return readFile(url, "utf8") + | } + | const response = await fetch(url) + | if (!response.ok) throw new Error(`Failed to load WAT from ${url}`) + | return response.text() + |} + | + |export async function __mlx_compileWatFromUrl(url, importObject) { + | const wat = await __mlx_loadWatText(url) + | const mod = binaryen.parseText(wat) + | mod.setFeatures(binaryen.Features.All) + | if (!mod.validate()) throw new Error(`Generated WAT is invalid: ${url}`) + | const modBuf = mod.emitBinary() + | mod.dispose() + | return WebAssembly.instantiate(modBuf, importObject) + |} + | + |const __mlx_intrinsicsPromise = + | __mlx_compileWatFromUrl(new URL("./RuntimeWASM.wat", import.meta.url), {}) + | + |export async function __mlx_buildSystem(systemMemMinPages) { + | const mem = new WebAssembly.Memory({ initial: systemMemMinPages }) + | const decodeUtf16 = new TextDecoder("utf-16le") + | const intrinsicModule = await __mlx_intrinsicsPromise + | return { + | mem, + | mlx_str_from_utf16: (ptr, byteLen) => + | decodeUtf16.decode(new Uint8Array(mem.buffer, ptr, byteLen)), + | ...intrinsicModule.instance.exports + | } + |} + |""".stripMargin + + /** Ensures the shared Wasm runtime helpers (`RuntimeWASM.{wat,mjs}`) are present + * in the given output directory. */ + private def ensureWasmRuntimeArtifacts(moduleDir: io.Path)(using Raise, State): Unit = + val watFile = moduleDir / io.RelPath("RuntimeWASM.wat") + val mjsFile = moduleDir / io.RelPath("RuntimeWASM.mjs") + if !cctx.fs.exists(watFile) then + val baseScp = utils.Scope.empty(utils.Scope.Cfg.default) + val watb = ltl.givenIn: + new WatBuilder() + val wat = baseScp.nest.givenIn: + watb.intrinsicSupportModule().mkString(100) + cctx.fs.write(watFile, wat) + if !cctx.fs.exists(mjsFile) then + cctx.fs.write(mjsFile, wasmRuntimeMjsSource) + + /** Compiles an imported Wasm module, or scans cached full elaboration for its session exports. */ + private def compileWasmDep( + file: io.Path, + prelude: Ctx, + memo: mutable.Map[io.Path, Opt[CompiledWasmModule]], + )(using State, CompilerCtx, SymbolPrinter): Opt[CompiledWasmModule] = + memo.get(file) match + case S(cached) => cached + case N => + given Raise = mkRaise(file) + val cachedArtifact = CompilerCtx.get.getCachedElaboratedBlock(file, full = true) + val artifact = cachedArtifact.getOrElse: + etl.givenIn: + CompilerCtx.get.getElaboratedBlock(file, prelude, true) + val blk = artifact.term + val effectiveCfg = Config.extractConfigFromStats(blk) + val compiled = + if effectiveCfg.target =/= CompilationTarget.Wasm then + raise(ErrorReport( + msg"Wasm modules can only import Wasm-targeted `.mls` files; " + + msg"`${file.toString}` targets `${effectiveCfg.target.toString}`" -> N :: Nil, + source = Diagnostic.Source.Compilation)) + N + else + prelude.givenIn: + val exportedSymbol = artifact.tree.definedSymbols.find(_._1 === file.baseName).map(_._2) + val preservedSymbols = preservedSymbolsFor(exportedSymbol) + effectiveCfg.givenIn: + val low = ltl.givenIn: + new codegen.Lowering() with codegen.LoweringSelSanityChecks + val le0 = low.program(blk) + val le1 = ltl.givenIn: + codegen.BlockSimplifier(preservedSymbols)(le0) + val program = ltl.givenIn: + codegen.DeadParamElim(le1) + val symbolsToExport = preservedSymbols ++ program.main.definedVars + cachedArtifact match + case S(_) => + val watb = ltl.givenIn: + new WatBuilder() + val exportScope = utils.Scope.empty(utils.Scope.Cfg.default).nest + val sessionExports = exportScope.givenIn: + watb.extractExports( + program.main, + Seq.empty, + symbolsToExport, + file.baseName, + ) + S(CompiledWasmModule(hkmc2.document.Document.empty, "entry", 0, sessionExports)) + case N => + emitWasm(file, program, exportedSymbol, preservedSymbols, prelude, memo) + memo(file) = compiled + compiled + + /** Wasm-specific emit: compile any `.mls` dependencies, run `WatBuilder`, write + * `.wat`, `.mjs` glue, and ensure shared `RuntimeWASM.{wat,mjs}` helpers exist. */ + private def emitWasm( + file: io.Path, + program: codegen.Program, + exportedSymbol: Opt[BlockMemberSymbol], + preservedSymbols: Set[codegen.Local], + prelude: Ctx, + memo: mutable.Map[io.Path, Opt[CompiledWasmModule]], + )(using Raise, State, CompilerCtx, SymbolPrinter): Opt[CompiledWasmModule] = + val moduleDir = file.up + val moduleBaseName = file.baseName + val edges = mutable.ArrayBuffer.empty[(Str, Str)] + val compiledDeps = mutable.ArrayBuffer.empty[CompiledWasmModule] + var failed = false + program.imports.foreach: + case (sym, importPath) => + resolveImportSource(importPath, moduleDir) match + case S(sourceFile) => + compileWasmDep(sourceFile, prelude, memo) match + case S(dep) => + edges += (importPath -> sourceFile.baseName) + compiledDeps += dep + case N => failed = true + case N => + raise(ErrorReport( + msg"Wasm modules currently only support imports of Wasm-targeted `.mls` files; " + + msg"`${importPath}` cannot be resolved that way" -> sym.toLoc :: Nil, + source = Diagnostic.Source.Compilation)) + failed = true + if failed then N + else + val sessionImports = + val seen = mutable.LinkedHashSet.empty[Str] + compiledDeps.iterator.flatMap(_.sessionExports).filter(b => seen.add(b.bindingKey)).toSeq + val baseScp = utils.Scope.empty(utils.Scope.Cfg.default) + val watb = ltl.givenIn: + new WatBuilder() + val symbolsToExport = preservedSymbols ++ program.main.definedVars + val compiled = baseScp.nest.givenIn: + watb.program( + program, + exportedSymbol, + moduleDir, + sessionImports, + symbolsToExport, + moduleBaseName, + ) + ensureWasmRuntimeArtifacts(moduleDir) + cctx.fs.write(moduleDir / io.RelPath(s"$moduleBaseName.wat"), compiled.wat.mkString(100)) + cctx.fs.write( + moduleDir / io.RelPath(s"$moduleBaseName.mjs"), + wasmGlue(moduleDir, moduleBaseName, compiled, exportedSymbol, edges.toSeq), + ) + S(compiled) + def compileModule(file: io.Path): Unit = val wd = file.up @@ -68,7 +337,6 @@ class MLsCompiler given Elaborator.State = new Elaborator.State: override def dbg: Bool = dbgElab - // TODO adapt logic given SymbolPrinter = new SymbolPrinter( Scope.empty(Scope.Cfg.default.copy( escapeChars = false, @@ -76,11 +344,6 @@ class MLsCompiler includeZero = true, )) ) - val etl = new TraceLogger{override def doTrace: Bool = false} - val ltl = new TraceLogger{override def doTrace: Bool = false} - // val ltl = new TraceLogger{override def doTrace: Bool = true} - val rtl = new TraceLogger{override def doTrace: Bool = false} - val preludeParse = ParserSetup(preludeFile, dbgParsing) val mainParse = ParserSetup(file, dbgParsing) @@ -88,56 +351,51 @@ class MLsCompiler val initState = State.init.nestLocal("prelude") - val (pblk, newCtx) = elab.importFrom(preludeParse.resultBlk)(using initState) + val (_, newCtx) = elab.importFrom(preludeParse.resultBlk)(using initState) newCtx.nestLocal("file:"+file.baseName).givenIn: given CompilerCtx = cctx.derive(file) val elab = Elaborator(etl, wd, newCtx) val parsed = mainParse.resultBlk val (blk0, _) = elab.importFrom(parsed) - Config.extractConfigFromStats(blk0).givenIn { - val resolver = Resolver(rtl) - resolver.traverseBlock(blk0)(using Resolver.ICtx.empty) - def findQuote(t: semantics.Statement): Bool = t match - case Term.Quoted(_) | Term.Unquoted(_) => true - case Term.Ref(sym) => sym === State.termSymbol - case _ => t.subTerms.exists(findQuote) - val hasQuote = findQuote(blk0) - val blk = new Term.Blk( - Import(State.runtimeSymbol, runtimeFile.toString, runtimeFile) :: - // Only import `Term.mls` when necessary. - (if hasQuote then - Import(State.termSymbol, termFile.toString, termFile) :: blk0.stats - else - blk0.stats), - blk0.res - ) - val low = ltl.givenIn: - new codegen.Lowering() - with codegen.LoweringSelSanityChecks - val jsb = ltl.givenIn: - codegen.js.JSBuilder() - val le_0 = low.program(blk) - val nme = file.baseName - val exportedSymbol = parsed.definedSymbols.find(_._1 === nme).map(_._2) - val le_1 = ltl.givenIn: - codegen.BlockSimplifier(exportedSymbol.toSet)(le_0) - val le_2 = ltl.givenIn: - codegen.DeadParamElim(le_1) - val baseScp: utils.Scope = - utils.Scope.empty(utils.Scope.Cfg.default) - // * This line serves for `import.meta.url`, which retrieves directory and file names of mjs files. - // * Having `module id"import" with ...` in `prelude.mls` will generate `globalThis.import` that is undefined. - baseScp.addToBindings(Elaborator.State.importSymbol, "import", shadow = false) - val nestedScp = baseScp.nest - val je = nestedScp.givenIn: - jsb.program(le_2, exportedSymbol, wd) - val jsStr = je.stripBreaks.mkString(100) - val out = file.up / io.RelPath(file.baseName + ".mjs") - cctx.fs.write(out, jsStr) - } - + val effectiveCfg = Config.extractConfigFromStats(blk0) + effectiveCfg.givenIn: + val resolver = Resolver(rtl) + resolver.traverseBlock(blk0)(using Resolver.ICtx.empty) + def findQuote(t: semantics.Statement): Bool = t match + case Term.Quoted(_) | Term.Unquoted(_) => true + case Term.Ref(sym) => sym === State.termSymbol + case _ => t.subTerms.exists(findQuote) + val hasQuote = findQuote(blk0) + val exportedSymbol = parsed.definedSymbols.find(_._1 === file.baseName).map(_._2) + val preservedSymbols = preservedSymbolsFor(exportedSymbol) + val blk = effectiveCfg.target match + case CompilationTarget.JS => + new Term.Blk( + Import(State.runtimeSymbol, runtimeFile.toString, runtimeFile) :: + // Only import `Term.mls` when necessary. + (if hasQuote then + Import(State.termSymbol, termFile.toString, termFile) :: blk0.stats + else + blk0.stats), + blk0.res + ) + case CompilationTarget.Wasm => + blk0 + val low = ltl.givenIn: + new codegen.Lowering() + with codegen.LoweringSelSanityChecks + val le_0 = low.program(blk) + val le_1 = ltl.givenIn: + codegen.BlockSimplifier(preservedSymbols)(le_0) + val le_2 = ltl.givenIn: + codegen.DeadParamElim(le_1) + effectiveCfg.target match + case CompilationTarget.JS => + emitJs(file, wd, le_2, exportedSymbol) + case CompilationTarget.Wasm => + emitWasm(file, le_2, exportedSymbol, preservedSymbols, newCtx, mutable.Map.empty) + end MLsCompiler - diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala index a34296df20..51e6a8e91b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala @@ -50,9 +50,10 @@ final case class SessionFunc( moduleName: Str, exportName: Str, funcType: FunctionType, + aliasSyms: Seq[Local], ) extends SessionBinding: def bindingKey: Str = s"func:$moduleName:$exportName" - def bindingSyms: Seq[Local] = sym :: Nil + def bindingSyms: Seq[Local] = sym +: aliasSyms override def exportNameOpt: Opt[Str] = S(exportName) /** Metadata for an exported global that later Wasm REPL modules can import. @@ -150,25 +151,6 @@ final case class CompiledWasmModule( sessionExports: Seq[SessionBinding], ) -/** Context used while collecting REPL/session exports for a single Wasm module. - * - * @param symbolsToExport - * The symbols from the current module that should be recorded as session exports. - * @param collectedBindings - * The session bindings accumulated while compiling the current module. - */ -final class SessionExportCtx( - val symbolsToExport: Set[Local], - val collectedBindings: ArrayBuf[SessionBinding], -): - def shouldExport(sym: Local): Bool = symbolsToExport(sym) - - def emit(binding: SessionBinding): Unit = - collectedBindings += binding - - def freshCollector(): SessionExportCtx = - SessionExportCtx(symbolsToExport, ArrayBuf.empty) - /** A Wasm function and its associated information. * * Each instance of [[FuncInfo]] represents a single function definition in a WebAssembly module. diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index a17473b124..32b89860d5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -112,6 +112,30 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private def baseObjectRefType(nullable: Bool)(using Ctx): RefType = RefType(baseObjectTypeIdx, nullable = nullable) + /** Registers the Wasm intrinsic struct types shared across class lowering. */ + private def registerBaseRuntimeTypes()(using Ctx, Raise): Unit = + ctx.addType(TypeInfo( + sym = typeInfoBaseSym, + compType = StructType(Seq( + tagFieldSym -> Field(I32Type, mutable = false, id = "$tag"), + parentFieldSym -> Field( + RefType(TypeIdx(SymIdx("TypeInfoBase")), nullable = true), + mutable = false, + id = "$parent", + ), + )), + objectTag = N, + )) + + ctx.addType(TypeInfo( + sym = baseObjectSym, + compType = StructType(Seq( + typeInfoFieldSym -> Field(RefType(typeInfoBaseTypeIdx, nullable = false), mutable = true, id = "$typeinfo"), + )), + objectTag = S(ctx.getFreshObjectTag() ensuring (_ == 0)), + )) + end registerBaseRuntimeTypes + /** Returns the default Wasm value for one struct field when eagerly constructing an object instance. */ private def defaultStructFieldValue(field: Field)(using Ctx, Raise): Expr = field.ty match case refTy: RefType if refTy.nullable => ref.`null`(refTy.heapType) @@ -186,7 +210,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: && defn.auxParams.isEmpty && (!(defn.k is syntax.Obj) || defn.parentPath.isEmpty) && (!(defn.k is syntax.Obj) || defn.methods.isEmpty) - && defn.companion.isEmpty + && defn.companion.forall(_.methods.isEmpty) /** Returns singleton metadata when `sym` resolves to a registered singleton object. */ private def singletonInfoFor(sym: Local)(using Ctx): Opt[SingletonInfo] = @@ -216,46 +240,23 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: bufferable = N, )(N, Nil) - /** Registers the synthetic `Unit` singleton. */ - private def RegisterUnitSingleton()(using Ctx, FunctionCtx, Raise, SessionExportCtx): Unit = + private def ensureSyntheticUnitPredeclared()(using Ctx, Raise): Unit = val unitDefn = syntheticUnitDefn - val singletonOwner = unitDefn.isym match - case mos: ModuleOrObjectSymbol => S(mos) - case _ => N - if ctx.containsSingleton(unitDefn.sym) then return - if ctx.getType(unitDefn.sym).isEmpty then predeclareClassTypeInfoType(unitDefn) predeclareClassType(unitDefn) predeclareClassInit(unitDefn) predeclareClassConstructor(unitDefn) - predeclareClassTypeInfoGlobal(unitDefn) + predeclareClassTypeInfoGlobal(unitDefn, Set[Local](unitDefn.sym)) - returningTerm(Define(unitDefn, End(""))) + /** Registers the synthetic `Unit` singleton. */ + private def RegisterUnitSingleton()(using Ctx, FunctionCtx, Raise): Unit = + val unitDefn = syntheticUnitDefn + if ctx.containsSingleton(unitDefn.sym) then return - val typeInfo = ctx.getTypeInfo_!(unitDefn.sym) - val unitRttiTypeInfo = ctx.getTypeInfo_!(typeInfoTypeIdxs(unitDefn.sym)) - val unitTypeInfoGlobalInfo = ctx.getGlobalInfo_!(typeInfoGlobals(unitDefn.sym)) - val singletonInfo = ctx.getSingletonInfo(unitDefn.sym) getOrElse: - lastWords("Missing singleton metadata for synthetic Unit object") - // Record session metadata for the synthetic Unit singleton. - summon[SessionExportCtx].emit(SessionClass( - sym = unitDefn.sym, - wrapId = typeInfo.wrapId, - compType = typeInfo.compType, - objectTag = typeInfo.objectTag, - rttiTypeInfo = unitRttiTypeInfo, - rttiGlobalExportName = unitTypeInfoGlobalInfo.exportName.get, - aliasSyms = singletonOwner.toSeq, - )) - summon[SessionExportCtx].emit(SessionSingleton( - blockSym = unitDefn.sym, - wrapId = N -> N, - objectSym = singletonOwner, - moduleName = SessionBinding.ReplModuleName, - exportName = singletonInfo.globalName, - globalTy = singletonInfo.globalTy, - )) + ensureSyntheticUnitPredeclared() + + returningTerm(Define(unitDefn, End(""))) end RegisterUnitSingleton /** Registers eager singleton runtime state by creating its global and start-init action. */ @@ -380,15 +381,14 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: end sortTopLevelClasses /** Returns the elaborated semantic class definition for this lowered class. */ - private def semanticClassDef(defn: ClsLikeDefn)(using Raise): hkmc2.semantics.ClassLikeDef = + private def semanticClassDef(defn: ClsLikeDefn): Opt[hkmc2.semantics.ClassLikeDef] = defn.isym.defn match - case S(clsDef: hkmc2.semantics.ClassLikeDef) => clsDef - case _ => - lastWords(s"Expected definition of class `${defn.sym}` to be present") + case S(clsDef: hkmc2.semantics.ClassLikeDef) => S(clsDef) + case _ => N /** Returns the elaborated source methods for this class. */ - private def semanticMethodDefs(defn: ClsLikeDefn)(using Raise): List[TermDefinition] = - semanticClassDef(defn).body.methods.filter(_.body.nonEmpty) + private def semanticMethodDefs(defn: ClsLikeDefn): List[TermDefinition] = + semanticClassDef(defn).fold(Nil)(_.body.methods.filter(_.body.nonEmpty)) /** Resolves the exact overridden parent method symbol for `methodDef`, if any. */ private def overriddenParentMethodSym( @@ -546,19 +546,21 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: p.sym -> SymIdx(p.sym.nme) val ctorExportName = defn.sym .optionIf: sym => - !(defn.k is syntax.Obj) && sym.nameIsMeaningful + (!(defn.k is syntax.Obj) || defn.sym === State.unitBlockMemberSymbol) && sym.nameIsMeaningful .map(_.nme) + .map: name => + if defn.sym === State.unitBlockMemberSymbol then s"${State.unitBlockMemberSymbol.nme}_ctor" else name predeclareClassFunc(defn, "ctor", ctorParams, defn.sym, ctorExportName) /** Registers all Wasm pre-declarations needed for one top-level class, in dependency order. */ - private def predeclareClass(defn: ClsLikeDefn)(using Ctx, Raise, SessionExportCtx): Unit = + private def predeclareClass(defn: ClsLikeDefn, symbolsToExport: Set[Local])(using Ctx, Raise): Unit = predeclareClassVirtualTable(defn) predeclareClassTypeInfoType(defn) predeclareClassType(defn) predeclareClassInit(defn) predeclareClassConstructor(defn) predeclareClassMethods(defn) - predeclareClassTypeInfoGlobal(defn) + predeclareClassTypeInfoGlobal(defn, symbolsToExport) /** Collects the symbols that should live in mutable globals so later REPL blocks can import them. * @@ -567,7 +569,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: */ private def collectSessionGlobalSymbols( b: Block, - sessionExportCtx: SessionExportCtx, + symbolsToExport: Set[Local], ): Set[Symbol] = def restOf(block: Block): Opt[Block] = block match case Define(_, rst) => S(rst) @@ -584,11 +586,11 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: recur(body) case Begin(sub, rst) => recur(sub) ++ recur(rst) - case Define(ValDefn(_, sym, _), rst) if sessionExportCtx.shouldExport(sym) => + case Define(ValDefn(_, sym, _), rst) if symbolsToExport(sym) => recur(rst) + sym case Define(_, rst) => recur(rst) - case Assign(sym: Symbol, _, rst) if sessionExportCtx.shouldExport(sym) => + case Assign(sym: Symbol, _, rst) if symbolsToExport(sym) => recur(rst) + sym case _: BlockTail => Set.empty @@ -598,10 +600,190 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: recur(b) end collectSessionGlobalSymbols + def extractExports( + block: Block, + sessionImports: Seq[SessionBinding], + symbolsToExport: Set[Local], + currentModuleName: Str, + )(using Raise): Seq[SessionBinding] = + val ctx = Ctx.empty + given Ctx = ctx + val exports = ArrayBuf.empty[SessionBinding] + + registerBaseRuntimeTypes() + registerSessionImports(sessionImports) + + collectSessionGlobalSymbols(block, symbolsToExport).toSeq.sortBy(_.uid).foreach: sym => + exports += SessionGlobal( + sym = sym, + wrapId = N -> N, + moduleName = currentModuleName, + exportName = sym.nme, + globalType = GlobalType(RefType.anyref, mutable = true), + ) + + def scanTopLevel(block: Block)(onExportedDefn: Defn => Unit): Unit = + block match + case Scoped(_, body) => + scanTopLevel(body)(onExportedDefn) + case Begin(sub, rst) => + scanTopLevel(sub)(onExportedDefn) + scanTopLevel(rst)(onExportedDefn) + case Define(defn, rst) => + if symbolsToExport(defn.sym) then onExportedDefn(defn) + scanTopLevel(rst)(onExportedDefn) + case Match(_, _, _, rst) => + scanTopLevel(rst)(onExportedDefn) + case TryBlock(_, _, rst) => + scanTopLevel(rst)(onExportedDefn) + case Label(_, _, _, rst) => + scanTopLevel(rst)(onExportedDefn) + case Assign(_, _, rst) => + scanTopLevel(rst)(onExportedDefn) + case AssignField(_, _, _, rst) => + scanTopLevel(rst)(onExportedDefn) + case AssignDynField(_, _, _, _, rst) => + scanTopLevel(rst)(onExportedDefn) + case _: BlockTail => + () + + var usesUnit = false + new BlockTraverser: + override def applyValue(v: Value): Unit = v match + case Value.Ref(l, disamb) => + if (l is State.unitSymbol) || disamb.contains(State.unitSymbol) then + usesUnit = true + super.applyValue(v) + case _ => + super.applyValue(v) + + override def applyPath(p: Path): Unit = p match + case sel @ Select(_, _) => + sel.symbol.foreach: + case sym: ModuleOrObjectSymbol if sym is State.unitSymbol => + usesUnit = true + case _ => () + super.applyPath(sel) + case _ => + super.applyPath(p) + .applyBlock(block) + + boundary[Unit]: + given Raise = diag => + diag match + case _: ErrorReport => break(()) + case _ => () + val ordered = sortTopLevelClasses(collectTopLevelClassDefns(block)) + ordered.foreach: defn => + predeclareClass(defn, symbolsToExport) + + if usesUnit && !ctx.containsSingleton(State.unitBlockMemberSymbol) then + ensureSyntheticUnitPredeclared() + registerSingletonInit(syntheticUnitDefn, ctx.getType_!(State.unitBlockMemberSymbol)) + val unitDefn = syntheticUnitDefn + val singletonOwner = unitDefn.isym match + case mos: ModuleOrObjectSymbol => S(mos) + case _ => N + val typeInfo = ctx.getTypeInfo_!(unitDefn.sym) + val rttiTypeInfo = ctx.getTypeInfo_!(typeInfoTypeIdxs(unitDefn.sym)) + val rttiGlobalInfo = ctx.getGlobalInfo_!(typeInfoGlobals(unitDefn.sym)) + val singletonInfo = ctx.getSingletonInfo(unitDefn.sym) getOrElse: + lastWords("Missing singleton metadata for synthetic Unit object") + exports += SessionClass( + sym = unitDefn.sym, + wrapId = typeInfo.wrapId, + compType = typeInfo.compType, + objectTag = typeInfo.objectTag, + rttiTypeInfo = rttiTypeInfo, + rttiGlobalExportName = rttiGlobalInfo.exportName.get, + aliasSyms = singletonOwner.toSeq, + ) + ctx.getFuncInfo(unitDefn.sym).foreach: ctorInfo => + ctorInfo.exportName.foreach: exportName => + exports += SessionFunc( + sym = unitDefn.sym, + wrapId = ctorInfo.wrapId, + moduleName = currentModuleName, + exportName = exportName, + funcType = FunctionType(ctorInfo.getSignatureType), + aliasSyms = singletonOwner.toSeq, + ) + exports += SessionSingleton( + blockSym = unitDefn.sym, + objectSym = singletonOwner, + wrapId = N -> N, + moduleName = currentModuleName, + exportName = singletonInfo.globalName, + globalTy = singletonInfo.globalTy, + ) + + scanTopLevel(block): + case FunDefn(owner, sym, _, ps :: _, _) if owner.isEmpty && sym.nameIsMeaningful => + exports += SessionFunc( + sym = sym, + wrapId = N -> N, + moduleName = currentModuleName, + exportName = sym.nme, + funcType = FunctionType( + params = ps.params.map(p => WasmParam(SymIdx(p.sym.nme), RefType.anyref)), + results = Seq(Result(RefType.anyref)), + ), + aliasSyms = Nil, + ) + case clsLikeDefn: ClsLikeDefn => + ctx.getTypeInfo(clsLikeDefn.sym).foreach: typeInfo => + val isSingletonObj = clsLikeDefn.k is syntax.Obj + val rttiTypeInfo = ctx.getTypeInfo_!(typeInfoTypeIdxs(clsLikeDefn.sym)) + val rttiGlobalInfo = ctx.getGlobalInfo_!(typeInfoGlobals(clsLikeDefn.sym)) + exports += SessionClass( + sym = clsLikeDefn.sym, + wrapId = typeInfo.wrapId, + compType = typeInfo.compType, + objectTag = typeInfo.objectTag, + rttiTypeInfo = rttiTypeInfo, + rttiGlobalExportName = rttiGlobalInfo.exportName.get, + aliasSyms = clsLikeDefn.isym match + case mos: ModuleOrObjectSymbol => mos :: Nil + case _ => Nil, + ) + if !isSingletonObj && clsLikeDefn.sym.nameIsMeaningful then + val ctorParams = clsLikeDefn.paramsOpt.fold(Seq.empty[WasmParam]): ps => + ps.params.map(p => WasmParam(SymIdx(p.sym.nme), RefType.anyref)) + exports += SessionFunc( + sym = clsLikeDefn.sym, + wrapId = S(clsLikeDefn.sym.nme) -> N, + moduleName = currentModuleName, + exportName = clsLikeDefn.sym.nme, + funcType = FunctionType( + params = ctorParams, + results = Seq(Result(RefType.anyref)), + ), + aliasSyms = Nil, + ) + if isSingletonObj then + registerSingletonInit(clsLikeDefn, ctx.getType_!(clsLikeDefn.sym)) + ctx.getSingletonInfo(clsLikeDefn.sym).foreach: info => + val singletonOwner = clsLikeDefn.isym match + case mos: ModuleOrObjectSymbol => S(mos) + case _ => N + exports += SessionSingleton( + blockSym = clsLikeDefn.sym, + objectSym = singletonOwner, + wrapId = typeInfo.wrapId, + moduleName = currentModuleName, + exportName = info.globalName, + globalTy = info.globalTy, + ) + case _ => () + + val seen = scala.collection.mutable.LinkedHashSet.empty[Str] + exports.filter(binding => seen.add(binding.bindingKey)).toSeq + end extractExports + /** Declares a mutable exported global for a REPL-visible binding produced by the current block. */ private def registerSessionGlobal( sym: Symbol, - )(using Ctx, Raise, SessionExportCtx): Unit = + )(using Ctx, Raise): Unit = if ctx.containsGlobal(sym) then return val exportName = sym.nme val globalInfo = GlobalInfo( @@ -611,13 +793,6 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: sym, ) ctx.addGlobal(globalInfo) - summon[SessionExportCtx].emit(SessionGlobal( - sym = sym, - wrapId = globalInfo.wrapId, - moduleName = SessionBinding.ReplModuleName, - exportName = exportName, - globalType = GlobalType(RefType.anyref, mutable = true), - )) end registerSessionGlobal /** Registers imported REPL bindings into the current module before codegen starts. */ @@ -721,7 +896,10 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: end predeclareClassTypeInfoType /** Predeclares the shared runtime `typeinfo` global for one supported top-level class. */ - private def predeclareClassTypeInfoGlobal(defn: ClsLikeDefn)(using Ctx, Raise, SessionExportCtx): Unit = + private def predeclareClassTypeInfoGlobal( + defn: ClsLikeDefn, + symbolsToExport: Set[Local], + )(using Ctx, Raise): Unit = val typeInfoTypeIdx = typeInfoTypeIdxs(defn.sym) val tagValue = ctx.getTypeInfo_!(defn.sym).objectTag.get val parentTypeInfo = @@ -740,7 +918,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: globalType = GlobalType(RefType(typeInfoTypeIdx, nullable = false), mutable = false), init = struct.`new`(typeInfoTypeIdx, initFields), exportName = - if summon[SessionExportCtx].shouldExport(defn.sym) || defn.sym == syntheticUnitDefn.sym + if symbolsToExport(defn.sym) || defn.sym == syntheticUnitDefn.sym then S(s"${defn.sym.nme}_typeinfo") else N, sym = defn.sym, @@ -878,7 +1056,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: /** Compiles a class init body under its own Wasm-local frame with explicit `this`. */ private def setupInitLocals( clsLikeDefn: ClsLikeDefn, - )(using Ctx, Raise, SessionExportCtx): (Expr, FunctionCtx) = + )(using Ctx, Raise): (Expr, FunctionCtx) = genFuncBody(clsLikeDefn.paramsOpt.toList, thisSym = S(clsLikeDefn.isym)): val thisVar = funcCtx.lookupLocal_!(clsLikeDefn.isym, N) val preCtorWat = compilePreCtor(clsLikeDefn, thisVar) @@ -899,7 +1077,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private def compilePreCtor( clsLikeDefn: ClsLikeDefn, thisVar: LocalIdx, - )(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = + )(using Ctx, FunctionCtx, Raise): Expr = def withRest(block: NonBlockTail, rest: Block): Block = block match case Scoped(syms, _) => Scoped(syms, rest) case Begin(sub, _) => Begin(sub, rest) @@ -961,7 +1139,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private def normalizeEntryExpr( expr: Expr, isAbortive: Bool, - )(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = + )(using Ctx, FunctionCtx, Raise): Expr = if expr.resultTypes.isEmpty && !isAbortive then blockInstr( label = N, @@ -1013,7 +1191,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: loc: Opt[Loc], errCtx: Str, errExtra: => Str, - )(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr => Expr = + )(using Ctx, FunctionCtx, Raise): Expr => Expr = fld match case Value.Lit(IntLit(value)) if value.isValidInt => val idx = value.toInt @@ -1118,7 +1296,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ) end getVar - def argument(a: Arg)(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = + def argument(a: Arg)(using Ctx, FunctionCtx, Raise): Expr = if a.spread.nonEmpty then errExpr( Ls(msg"WatBackend::argument for spread expression not implemented yet" -> a.value.toLoc), @@ -1126,10 +1304,10 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ) else result(a.value) - def operand(a: Arg)(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = + def operand(a: Arg)(using Ctx, FunctionCtx, Raise): Expr = if a.spread.nonEmpty then die else subexpression(a.value) - def subexpression(r: codegen.Result)(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = r match + def subexpression(r: codegen.Result)(using Ctx, FunctionCtx, Raise): Expr = r match case r: Lambda => errExpr( Ls(msg"WatBuilder::subexpression for Lambda not implemented yet" -> r.toLoc), @@ -1165,7 +1343,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: qual: Path, methodSym: BlockMemberSymbol, args: Seq[Arg], - )(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = + )(using Ctx, FunctionCtx, Raise): Expr = val ownerCls = fieldOwner(methodSym).get ctx.getVirtualTable(ownerCls).flatMap(_.virtualMethodSlots.get(methodSym)) match case S(slot) => @@ -1202,7 +1380,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: returnTypes = Seq(Result(RefType.anyref)), ) - def result(r: codegen.Result)(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = r match + def result(r: codegen.Result)(using Ctx, FunctionCtx, Raise): Expr = r match case Value.This(sym) => // TODO(Derppening): Add type tracking and refinement for locals, remove the `ref.cast` ref.cast( @@ -1620,7 +1798,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case S(_) => drop(expr) case N => expr - def returningTerm(t: Block)(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = + def returningTerm(t: Block)(using Ctx, FunctionCtx, Raise): Expr = t match case Assign(l, r, rst) if l is State.noSymbol => val rExpr = result(r) @@ -1809,14 +1987,6 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: exportName = sym.optionIf(_.nameIsMeaningful).map(_.nme), ) ctx.addFunc(funcInfo) - if summon[SessionExportCtx].shouldExport(defn.sym) then - summon[SessionExportCtx].emit(SessionFunc( - sym = defn.sym, - wrapId = funcInfo.wrapId, - moduleName = SessionBinding.ReplModuleName, - exportName = sym.nme, - funcType = FunctionType(funcInfo.getSignatureType), - )) nop else @@ -1851,7 +2021,9 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: if isSingletonObj && clsLikeDefn.methods.nonEmpty then break(errUnimplExpr("methods.nonEmpty for object")) if clsLikeDefn.companion.isDefined then - break(errUnimplExpr("companion.isDefined")) + val companion = clsLikeDefn.companion.get + if companion.methods.nonEmpty then + break(errUnimplExpr("companion.methods.nonEmpty")) val ctorAuxParams = clsLikeDefn.auxParams.map: ps => ps.params.map: p => @@ -1960,49 +2132,8 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: lastWords( s"Class method `$methodDefn` with multiple parameter lists should be rejected in predeclaration pass", ) - if summon[SessionExportCtx].shouldExport(clsLikeDefn.sym) then - val rttiTypeInfo = ctx.getTypeInfo_!(typeInfoTypeIdxs(clsLikeDefn.sym)) - val rttiGlobalInfo = ctx.getGlobalInfo_!(typeInfoGlobals(clsLikeDefn.sym)) - summon[SessionExportCtx].emit(SessionClass( - sym = clsLikeDefn.sym, - wrapId = typeinfo.wrapId, - compType = typeinfo.compType, - objectTag = typeinfo.objectTag, - rttiTypeInfo = rttiTypeInfo, - rttiGlobalExportName = rttiGlobalInfo.exportName.get, - aliasSyms = clsLikeDefn.isym match - case mos: ModuleOrObjectSymbol => mos :: Nil - case _ => Nil, - )) - if !isSingletonObj && clsLikeDefn.sym.nameIsMeaningful then - summon[SessionExportCtx].emit(SessionFunc( - sym = clsLikeDefn.sym, - wrapId = ctorFuncInfo.wrapId, - moduleName = SessionBinding.ReplModuleName, - exportName = clsLikeDefn.sym.nme, - funcType = FunctionType( - SignatureType( - params = ctorFnCtx.params.map(p => WasmParam(p._2, RefType.anyref)), - results = Seq(Result(RefType.anyref)), - ), - ), - )) - end if if isSingletonObj then registerSingletonInit(clsLikeDefn, typeref) - if summon[SessionExportCtx].shouldExport(clsLikeDefn.sym) then - ctx.getSingletonInfo(clsLikeDefn.sym).foreach: info => - val singletonOwner = clsLikeDefn.isym match - case mos: ModuleOrObjectSymbol => S(mos) - case _ => N - summon[SessionExportCtx].emit(SessionSingleton( - blockSym = clsLikeDefn.sym, - wrapId = typeinfo.wrapId, - objectSym = singletonOwner, - moduleName = SessionBinding.ReplModuleName, - exportName = info.globalName, - globalTy = info.globalTy, - )) nop @@ -2311,30 +2442,14 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: wd: io.Path, sessionImports: Seq[SessionBinding], preservedSessionSymbols: Set[Local], + currentModuleName: Str, )(using Raise): CompiledWasmModule = - for imprt <- p.imports do - raise( - ErrorReport( - msg"Import of symbol `${imprt._2}` not implemented yet" -> imprt._1.toLoc :: Nil, - extraInfo = S(imprt), - source = Diagnostic.Source.Compilation, - ), - ) - exprt.foreach: exprt => - raise( - ErrorReport( - msg"Export of symbol `${exprt.nme}` not implemented yet" -> exprt.toLoc :: Nil, - extraInfo = S(exprt), - source = Diagnostic.Source.Compilation, - ), - ) - - val sessionExportCtx = SessionExportCtx( - symbolsToExport = preservedSessionSymbols, - collectedBindings = ArrayBuf.empty, + val sessionExports = extractExports( + p.main, + sessionImports, + preservedSessionSymbols, + currentModuleName, ) - given SessionExportCtx = sessionExportCtx - val ctx = Ctx.empty given Ctx = ctx @@ -2345,38 +2460,15 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ).fold(0)(_.memType.lim.min) def compiledModule(entryName: Str): CompiledWasmModule = - CompiledWasmModule(ctx.toWat, entryName, systemMemMinPages, sessionExportCtx.collectedBindings.toSeq) + CompiledWasmModule(ctx.toWat, entryName, systemMemMinPages, sessionExports) - // Create the two Wasm intrinsic struct types shared across class lowering: - // the base RTTI layout (`TypeInfoBase`) and the base object layout (`Object`). - ctx.addType(TypeInfo( - sym = typeInfoBaseSym, - compType = StructType(Seq( - tagFieldSym -> Field(I32Type, mutable = false, id = "$tag"), - parentFieldSym -> Field( - RefType(TypeIdx(SymIdx("TypeInfoBase")), nullable = true), - mutable = false, - id = "$parent", - ), - )), - objectTag = N, - )) - - ctx.addType(TypeInfo( - sym = baseObjectSym, - compType = StructType(Seq( - typeInfoFieldSym -> Field(RefType(typeInfoBaseTypeIdx, nullable = false), mutable = true, id = "$typeinfo"), - )), - objectTag = S(ctx.getFreshObjectTag() ensuring (_ == 0)), - )) + registerBaseRuntimeTypes() registerSessionImports(sessionImports) - collectSessionGlobalSymbols( - p.main, - sessionExportCtx, - ).toSeq.sortBy(_.uid).foreach: sym => - registerSessionGlobal(sym) + sessionExports.foreach: + case global: SessionGlobal => registerSessionGlobal(global.sym) + case _ => () boundary[CompiledWasmModule]: val outerRaise = summon[Raise] @@ -2390,7 +2482,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case _: ErrorReport => break(compiledModule("entry")) case _ => () val ordered = sortTopLevelClasses(collectTopLevelClassDefns(p.main)) - ordered.foreach(predeclareClass) + ordered.foreach(predeclareClass(_, preservedSessionSymbols)) // Compile the entry function under a dedicated local scope so that any temp locals introduced // during codegen (e.g., via `local.tee`) are declared in the entry function. @@ -2462,20 +2554,20 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: def nonNestedScoped( blk: Block, - )(k: Block => Expr)(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = blk match + )(k: Block => Expr)(using Ctx, FunctionCtx, Raise): Expr = blk match case Scoped(syms, body) => blockPreamble(syms.view.filter(body.freeVars)) k(body) case _ => k(blk) - def block(t: Block)(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = + def block(t: Block)(using Ctx, FunctionCtx, Raise): Expr = nonNestedScoped(t)(returningTerm) def setupFunction( thisParam: Opt[InnerSymbol], params: ParamList, body: Block, - )(using Ctx, Raise, SessionExportCtx): (Expr, FunctionCtx) = + )(using Ctx, Raise): (Expr, FunctionCtx) = genFuncBody(params :: Nil, thisSym = thisParam): block(body) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala index 809cde60b1..c3bd2971f5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala @@ -392,8 +392,18 @@ extends Importer with ucs.SplitElaborator: | Keyword.`private` )) => S(Annot.Modifier(kw)) case App(Ident("config"), Tup(args)) => - val modify = ConfigParser.parseOverrides(args) - S(Annot.Config(modify)) + // Config target flag override is not allowed + val hasTargetOverride = args.collectFirst: + case InfixApp(Ident("target"), Keywrd(Keyword.`:`), _) => () + .nonEmpty + if hasTargetOverride then + raise(ErrorReport( + msg"`target` can only be configured with a top-level #config(...) directive" -> tree.toLoc :: Nil, + source = Diagnostic.Source.Compilation)) + N + else + val modify = ConfigParser.parseOverrides(args) + S(Annot.Config(modify)) case _ => term(tree) match case Term.Error => N case trm => diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala index f48590b82b..ef41d275bf 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala @@ -54,7 +54,7 @@ class Importer: val importedSym = tl.trace(s">>> Importing $file"): given TL = tl - val artifact = cctx.getElaboratedBlock(file, prelude) + val artifact = cctx.getElaboratedBlock(file, prelude, cfg.target is CompilationTarget.Wasm) artifact.tree.definedSymbols.find(_._1 === nme) match case Some(nme -> imsym) => imsym case None => lastWords(s"File $file does not define a symbol named $nme") diff --git a/hkmc2/shared/src/test/mlscript-compile/.gitignore b/hkmc2/shared/src/test/mlscript-compile/.gitignore index 0d28e12d0c..61b190887c 100644 --- a/hkmc2/shared/src/test/mlscript-compile/.gitignore +++ b/hkmc2/shared/src/test/mlscript-compile/.gitignore @@ -1,4 +1,5 @@ *.mjs +*.wat !/Predef.mjs !/Runtime.mjs !/RuntimeJS.mjs diff --git a/hkmc2/shared/src/test/mlscript-compile/wasm/ImportValue.mls b/hkmc2/shared/src/test/mlscript-compile/wasm/ImportValue.mls new file mode 100644 index 0000000000..6f753b8cc8 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/wasm/ImportValue.mls @@ -0,0 +1,7 @@ +#config(target: CompilationTarget.Wasm) + +import "./SimpleTarget.mls" + +module ImportValue with ... + +val value = SimpleTarget.value + 1 diff --git a/hkmc2/shared/src/test/mlscript-compile/wasm/SimpleTarget.mjs b/hkmc2/shared/src/test/mlscript-compile/wasm/SimpleTarget.mjs new file mode 100644 index 0000000000..7778558f04 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/wasm/SimpleTarget.mjs @@ -0,0 +1,21 @@ +import { __mlx_compileWatFromUrl, __mlx_buildSystem } from "./RuntimeWASM.mjs" + + +const __mlx_wat_url = new URL("./SimpleTarget.wat", import.meta.url) + +async function __mlx_importObject() { + const importObject = { + system: await __mlx_buildSystem(0) + } + + return importObject +} + +const __mlx_wasmPromise = __mlx_importObject() + .then(importObject => __mlx_compileWatFromUrl(__mlx_wat_url, importObject)) + +export const __mlx_wasm = () => __mlx_wasmPromise + +const { instance } = await __mlx_wasm() + +export default instance.exports["SimpleTarget"]() diff --git a/hkmc2/shared/src/test/mlscript-compile/wasm/SimpleTarget.mls b/hkmc2/shared/src/test/mlscript-compile/wasm/SimpleTarget.mls new file mode 100644 index 0000000000..3e4d39fa0d --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/wasm/SimpleTarget.mls @@ -0,0 +1,5 @@ +#config(target: CompilationTarget.Wasm) + +module SimpleTarget with ... + +val value = 40 + 2 diff --git a/hkmc2/shared/src/test/mlscript-compile/wasm/SimpleTarget.wat b/hkmc2/shared/src/test/mlscript-compile/wasm/SimpleTarget.wat new file mode 100644 index 0000000000..d8847a1ed3 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/wasm/SimpleTarget.wat @@ -0,0 +1,72 @@ +(module + (type $TypeInfoBase (sub (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) + (type $Object (sub (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) + (type $SimpleTarget_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) + (type $SimpleTarget (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) + (type $SimpleTarget_init (func (param $this (ref null any)) (result (ref null any)))) + (type $SimpleTarget_ctor (func (result (ref null any)))) + (type $Unit_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) + (type $Unit (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) + (type $Unit_init (func (param $this (ref null any)) (result (ref null any)))) + (type $Unit_ctor (func (result (ref null any)))) + (type $entry (func (result (ref null any)))) + (type $start (func)) + (global $SimpleTarget_typeinfo (export "SimpleTarget_typeinfo") (ref $SimpleTarget_typeinfo) (struct.new $SimpleTarget_typeinfo + (i32.const 1) + (ref.null $TypeInfoBase))) + (global $Unit_typeinfo (export "Unit_typeinfo") (ref $Unit_typeinfo) (struct.new $Unit_typeinfo + (i32.const 2) + (ref.null $TypeInfoBase))) + (global $Unit$inst (export "Unit$inst") (mut (ref null $Unit)) (ref.null $Unit)) + (func $SimpleTarget_init (type $SimpleTarget_init) (param $this (ref null any)) (result (ref null any)) + (block (result (ref null any)) + (nop) + (nop) + (return + (local.get $this)))) + (func $SimpleTarget_ctor (export "SimpleTarget") (type $SimpleTarget_ctor) (result (ref null any)) + (local $this (ref null any)) + (block (result (ref null any)) + (local.set $this + (struct.new $SimpleTarget + (global.get $SimpleTarget_typeinfo))) + (drop + (call $SimpleTarget_init + (local.get $this))) + (return + (local.get $this)))) + (func $Unit_init (type $Unit_init) (param $this (ref null any)) (result (ref null any)) + (block (result (ref null any)) + (nop) + (nop) + (return + (local.get $this)))) + (func $Unit_Unit (export "Unit_ctor") (type $Unit_ctor) (result (ref null any)) + (local $this (ref null any)) + (block (result (ref null any)) + (local.set $this + (struct.new $Unit + (global.get $Unit_typeinfo))) + (drop + (call $Unit_init + (local.get $this))) + (return + (local.get $this)))) + (func $start (type $start) + (block + (global.set $Unit$inst + (ref.cast (ref null $Unit) + (call $Unit_Unit))))) + (func $entry (export "entry") (type $entry) (result (ref null any)) + (block (result (ref null any)) + (block + (nop) + (nop)) + (global.get $Unit$inst))) + (elem $SimpleTarget_init declare func $SimpleTarget_init) + (elem $SimpleTarget_ctor declare func $SimpleTarget_ctor) + (elem $Unit_init declare func $Unit_init) + (elem $Unit_Unit declare func $Unit_Unit) + (elem $start declare func $start) + (elem $entry declare func $entry) + (start $start)) \ No newline at end of file diff --git a/hkmc2/shared/src/test/mlscript/codegen/ImportMLsJS.mls b/hkmc2/shared/src/test/mlscript/codegen/ImportMLsJS.mls index 8adf693f50..9ec9535106 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ImportMLsJS.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ImportMLsJS.mls @@ -1,21 +1,10 @@ :js -import "../../mlscript-compile/Option.mjs" -//│ Option = class Option { -//│ Some: fun Some { class: class Some }, -//│ None: None, -//│ Both: fun Both { class: class Both }, -//│ unsafe: class unsafe -//│ } +import "../../mlscript-compile/Option.mls" -:e Option.isDefined(new Option.Some(1)) -//│ ╔══[COMPILATION ERROR] Expected a statically known class; found selection. -//│ ║ l.14: Option.isDefined(new Option.Some(1)) -//│ ║ ^^^^^^^^^^^ -//│ ╙── The 'new' keyword requires a statically known class; use the 'new!' operator for dynamic instantiation. //│ = true Option.isDefined(new! Option.Some(1)) @@ -53,3 +42,7 @@ new! Option.Some(1) //│ = Some(1) +import "../../mlscript-compile/wasm/SimpleTarget.mls" + +SimpleTarget.value + 1 +//│ = NaN diff --git a/hkmc2/shared/src/test/mlscript/wasm/ControlFlow.mls b/hkmc2/shared/src/test/mlscript/wasm/ControlFlow.mls index 22fac443bd..bb2781fdea 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ControlFlow.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ControlFlow.mls @@ -36,7 +36,7 @@ let i = 0 in //│ (nop) //│ (return //│ (local.get $this)))) -//│ (func $Unit_Unit (type $Unit_ctor) (result (ref null any)) +//│ (func $Unit_Unit (export "Unit_ctor") (type $Unit_ctor) (result (ref null any)) //│ (local $this (ref null any)) //│ (block (result (ref null any)) //│ (local.set $this diff --git a/hkmc2/shared/src/test/mlscript/wasm/ReplImports.mls b/hkmc2/shared/src/test/mlscript/wasm/ReplImports.mls index 74aaf0a408..7aeba3bd93 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ReplImports.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ReplImports.mls @@ -59,7 +59,9 @@ unit //│ (type $Object (sub (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) //│ (type $Unit (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) //│ (type $Unit_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) +//│ (type $Unit_Unit (func (result (ref null any)))) //│ (type $entry (func (result (ref null any)))) +//│ (import "repl" "Unit_ctor" (func $Unit_Unit (type $Unit_Unit))) //│ (import "repl" "Unit_typeinfo" (global $Unit_typeinfo (ref $Unit_typeinfo))) //│ (import "repl" "Unit$inst" (global $Unit (mut (ref null $Unit)))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) diff --git a/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls b/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls index b0ca76e7a4..883bc35ed8 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls @@ -24,7 +24,7 @@ //│ (nop) //│ (return //│ (local.get $this)))) -//│ (func $Unit_Unit (type $Unit_ctor) (result (ref null any)) +//│ (func $Unit_Unit (export "Unit_ctor") (type $Unit_ctor) (result (ref null any)) //│ (local $this (ref null any)) //│ (block (result (ref null any)) //│ (local.set $this diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/WasmDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/WasmDiffMaker.scala index ca24478367..5de82f212a 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/WasmDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/WasmDiffMaker.scala @@ -103,7 +103,14 @@ abstract class WasmDiffMaker extends LlirDiffMaker: bindings.foreach: (bindingKey, binding) => sessionImports.update(bindingKey, binding) val CompiledWasmModule(modWat, mainFnNme, systemMemMinPages, sessionExports) = ltl.givenIn: - WatBuilder().program(pgrm, N, wd, sessionImports.values.toSeq, symbolsToPreserve) + WatBuilder().program( + pgrm, + N, + wd, + sessionImports.values.toSeq, + symbolsToPreserve, + SessionBinding.ReplModuleName, + ) val modWatJsLit = JSBuilder.makeStringLiteral(modWat.mkString(output.ColWidth)) if wat.isSet then