diff --git a/README.md b/README.md index 7c3516ade2..29b1f50534 100644 --- a/README.md +++ b/README.md @@ -28,12 +28,44 @@ The entrypoint for the compiler is in `src/stanc/stanc.ml` which sequences the v ### Distinct stanc Phases The phases of stanc are summarized in the following information flowchart and list. - + ![stanc3 information flow](docs/img/information-flow.png) 1. [Lex](src/frontend/lexer.mll) the Stan language into tokens. 1. [Parse](src/frontend/parser.mly) Stan language into AST that represents the syntax quite closely and aides in development of pretty-printers and linters. `stanc --debug-ast` to print this out. -1. Typecheck & add type information [Semantic_check.ml](src/frontend/Semantic_check.ml). `stanc --debug-decorated-ast` +1. Typecheck & add type information [Typechecker.ml](src/frontend/Typechecker.ml). `stanc --debug-decorated-ast` 1. [Lower](src/frontend/Ast_to_Mir.ml) into [Middle Intermediate Representation](src/middle/Mir.ml) (AST -> MIR) `stanc --debug-mir` (or `--debug-mir-pretty`) 1. Analyze & optimize (MIR -> MIR) 1. Backend MIR transform (MIR -> MIR) [Transform_Mir.ml](src/stan_math_backend/Transform_Mir.ml) `stanc --debug-transformed-mir` diff --git a/docs/img/information-flow.png b/docs/img/information-flow.png index 7b37f4b0db..2956f7abad 100644 Binary files a/docs/img/information-flow.png and b/docs/img/information-flow.png differ diff --git a/src/common/Common.ml b/src/common/Common.ml index 44d102e602..690a60392d 100644 --- a/src/common/Common.ml +++ b/src/common/Common.ml @@ -2,7 +2,6 @@ module Foldable = Foldable (* Other signatures *) -module Validation = Validation module Pretty = Pretty (* 'Two-level type' signatures and functors *) diff --git a/src/common/Common.mli b/src/common/Common.mli index 4b66db925d..1901c4f36f 100644 --- a/src/common/Common.mli +++ b/src/common/Common.mli @@ -4,7 +4,6 @@ module Foldable : module type of Foldable (** Other signatures *) -module Validation : module type of Validation module Pretty : module type of Pretty (** 'Two-level type' signatures and functors *) diff --git a/src/common/Validation.ml b/src/common/Validation.ml deleted file mode 100644 index 6f94a53af8..0000000000 --- a/src/common/Validation.ml +++ /dev/null @@ -1,167 +0,0 @@ -include Core_kernel - -module type Infix = sig - type 'a t - - val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t -end - -module type S = sig - type error - type 'a t - - val map : 'a t -> f:('a -> 'b) -> 'b t - val pure : 'a -> 'a t - val apply : 'a t -> f:('a -> 'b) t -> 'b t - val apply_const : 'a t -> 'b t -> 'b t - val bind : 'a t -> f:('a -> 'b t) -> 'b t - val liftA2 : ('a -> 'b -> 'c) -> 'a t -> 'b t -> 'c t - val liftA3 : ('a -> 'b -> 'c -> 'd) -> 'a t -> 'b t -> 'c t -> 'd t - val sequence : 'a t list -> 'a list t - val ok : 'a -> 'a t - val error : error -> _ t - val is_error : 'a t -> bool - val is_success : 'a t -> bool - val get_errors_opt : 'a t -> error list option - val get_first_error_opt : 'a t -> error option - val get_success_opt : 'a t -> 'a option - - val get_with : - 'a t -> with_ok:('a -> 'b) -> with_errors:(error list -> 'b) -> 'b - - val to_result : 'a t -> ('a, error list) result - - module Validation_infix : Infix with type 'a t := 'a t - include Infix with type 'a t := 'a t -end - -(** - It is common need in a compiler to check for certain errors e.g. syntax - errors, type errors, other semantic errors, etc. Some typical ways of doing - this include exceptions or representing success/failure with `Result.t`. - In both of these approaches, the computation will 'fail' when we reach the - first error condition. - - In some situations, it may be preferable to report _all_ of these errors to a - user in one go. `Validation` gives us this ability*. - - The implementation below is an example of an OCaml `functor` i.e. a module - that is _parametrized_ by another module. Our `Valiation.Make` functor is - parameterized over another module containing the type of errors. - - When applied, the functor returns a module with signature `Validation.S` but - with the underlying error type fixed at the type `t` contained in the module `X`. - As an example, we use this functor with the module `Semantic_error` - (defined in `frontend`) so the type `error` is made equal to `Semantic_error.t` - - Under the hood, `Validation` is just a `Result.t` with a `NonEmpty` list of - `error`s. The main workhorse of the module is the `apply` function. - The function accepts a value lifted into our result type `'a t` and a - _function_ lifted into the result type `('a -> 'b) t`. - - ``` - let apply x ~f = - match (f,x) with - ``` - - `apply` 'unwraps' the function and the value and performs different actions - depending on whether an error has occured. Because we have two arguments each - of which can be in two states, we have to check for four conditions: - - The first condition is that both the function `f` and the value `x` are `Ok`; - in this case we can simply apply the function to the value and lift is back - into the result type: - - ``` - | Ok f , Ok x -> Ok (f x) - ``` - - The second situation is the the function is an `Error`; in this case we can't - apply the function so we simply return the same error: - - ``` - | Error e , Ok _ -> Error e - ``` - - The third situation is that the function `f` is `Ok` but he value `x` is an - `Error`; again, we can't apply the function so we return the error: - - ``` - | Ok _ , Error e -> Error e - ``` - - The final situation is that both `f` and `x` are errors; because our error - type is a `NonEmpty` list, we can combine them using the `append` function and - track both: - - ``` - | Error e1 , Error e2 -> Error(append e1 e2) - ``` - - The module hides the implementation of `'a t` but exposes the functions `ok` - and `error` to construct this type. - - * This is true so long as we restrict our usage to `apply` (and related) - functions i.e. the _applicative_ interface.As soon as we use `bind` - (i.e. the _monadic_ interface) we lose this ability. Technically, `Validation` - doesn't constitute a monad since it violates the law: - - ``` - liftM2 f m1 m2 === apply f a1 s2 - ``` -*) -module Make (X : sig - type t -end) : S with type error := X.t = struct - type errors = NonEmpty of X.t * X.t list - - let append xs ys = - match (xs, ys) with - | NonEmpty (x, []), NonEmpty (y, []) -> NonEmpty (x, [y]) - | NonEmpty (x, xs), NonEmpty (y, ys) -> NonEmpty (x, xs @ (y :: ys)) - - type 'a t = ('a, errors) result - - let map x ~f = match x with Ok x -> Ok (f x) | Error x -> Error x - let pure x = Ok x - - let apply x ~f = - match (f, x) with - | Ok f, Ok x -> Ok (f x) - | Error e, Ok _ -> Error e - | Ok _, Error e -> Error e - | Error e1, Error e2 -> Error (append e1 e2) - - let apply_const a b = apply b ~f:(map a ~f:(fun _ x -> x)) - let bind x ~f = match x with Ok x -> f x | Error e -> Error e - let liftA2 f x y = apply y ~f:(apply x ~f:(pure f)) - let liftA3 f x y z = apply z ~f:(apply y ~f:(apply x ~f:(pure f))) - let consA next rest = liftA2 List.cons next rest - let sequence ts = List.fold_right ~init:(pure []) ~f:consA ts - - module Validation_infix = struct let ( >>= ) x f = bind x ~f end - include Validation_infix - - let ok x = pure x - let error x = Error (NonEmpty (x, [])) - let is_error = function Error _ -> true | _ -> false - let is_success = function Ok _ -> true | _ -> false - - let get_errors_opt = function - | Error (NonEmpty (x, xs)) -> Some (x :: xs) - | _ -> None - - let get_first_error_opt = function - | Error (NonEmpty (x, _)) -> Some x - | _ -> None - - let get_success_opt = function Ok x -> Some x | _ -> None - - let get_with x ~with_ok ~with_errors = - match x with - | Ok x -> with_ok x - | Error (NonEmpty (x, xs)) -> with_errors @@ (x :: xs) - - let to_result x = - match x with Ok x -> Ok x | Error (NonEmpty (x, xs)) -> Error (x :: xs) -end diff --git a/src/common/Validation.mli b/src/common/Validation.mli deleted file mode 100644 index d158a3251c..0000000000 --- a/src/common/Validation.mli +++ /dev/null @@ -1,38 +0,0 @@ -module type Infix = sig - type 'a t - - val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t -end - -module type S = sig - type error - type 'a t - - val map : 'a t -> f:('a -> 'b) -> 'b t - val pure : 'a -> 'a t - val apply : 'a t -> f:('a -> 'b) t -> 'b t - val apply_const : 'a t -> 'b t -> 'b t - val bind : 'a t -> f:('a -> 'b t) -> 'b t - val liftA2 : ('a -> 'b -> 'c) -> 'a t -> 'b t -> 'c t - val liftA3 : ('a -> 'b -> 'c -> 'd) -> 'a t -> 'b t -> 'c t -> 'd t - val sequence : 'a t list -> 'a list t - val ok : 'a -> 'a t - val error : error -> _ t - val is_error : 'a t -> bool - val is_success : 'a t -> bool - val get_errors_opt : 'a t -> error list option - val get_first_error_opt : 'a t -> error option - val get_success_opt : 'a t -> 'a option - - val get_with : - 'a t -> with_ok:('a -> 'b) -> with_errors:(error list -> 'b) -> 'b - - val to_result : 'a t -> ('a, error list) result - - module Validation_infix : Infix with type 'a t := 'a t - include Infix with type 'a t := 'a t -end - -module Make (X : sig - type t -end) : S with type error := X.t diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index 036dcf490f..f92a135f1d 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -878,5 +878,5 @@ let trans_prog filename (p : Ast.typed_program) : Program.Typed.t = ; generate_quantities ; transform_inits ; output_vars - ; prog_name= normalize_prog_name !Semantic_check.model_name + ; prog_name= normalize_prog_name !Typechecker.model_name ; prog_path= filename } diff --git a/src/frontend/Deprecation_analysis.mli b/src/frontend/Deprecation_analysis.mli index f56f288468..87cb9210ea 100644 --- a/src/frontend/Deprecation_analysis.mli +++ b/src/frontend/Deprecation_analysis.mli @@ -1,7 +1,7 @@ open Core_kernel open Ast -type t = Middle.Warnings.t +type t = Warnings.t val find_udf_log_suffix : typed_statement -> (string * Middle.UnsizedType.t) option diff --git a/src/frontend/Environment.ml b/src/frontend/Environment.ml new file mode 100644 index 0000000000..bff4284645 --- /dev/null +++ b/src/frontend/Environment.ml @@ -0,0 +1,48 @@ +open Core_kernel +open Middle + +type originblock = + | MathLibrary + | Functions + | Data + | TData + | Param + | TParam + | Model + | GQuant +[@@deriving sexp] + +type varinfo = {origin: originblock; global: bool; readonly: bool} +[@@deriving sexp] + +type info = + { type_: UnsizedType.t + ; kind: + [ `Variable of varinfo + | `UserDeclared of Location_span.t + | `StanMath + | `UserDefined ] } +[@@deriving sexp] + +type t = info list String.Map.t + +let create () = + let functions = + Hashtbl.to_alist Stan_math_signatures.stan_math_signatures + |> List.map ~f:(fun (key, values) -> + ( key + , List.map values ~f:(fun (rt, args, mem) -> + let type_ = + UnsizedType.UFun + (args, rt, Fun_kind.suffix_from_name key, mem) + in + {type_; kind= `StanMath} ) ) ) + |> String.Map.of_alist_exn + in + functions + +let add env key type_ kind = Map.add_multi env ~key ~data:{type_; kind} +let set_raw env key data = Map.set env ~key ~data +let find env key = Map.find_multi env key +let mem env key = Map.mem env key +let iter env f = Map.iter env ~f diff --git a/src/frontend/Environment.mli b/src/frontend/Environment.mli new file mode 100644 index 0000000000..74fdb80c87 --- /dev/null +++ b/src/frontend/Environment.mli @@ -0,0 +1,53 @@ +(** Type environments used during typechecking. Maps from strings to function or variable information *) + +open Middle + +(** Origin blocks, to keep track of where variables are declared *) +type originblock = + | MathLibrary + | Functions + | Data + | TData + | Param + | TParam + | Model + | GQuant +[@@deriving sexp] + +(** Information available for each variable *) +type varinfo = {origin: originblock; global: bool; readonly: bool} +[@@deriving sexp] + +type info = + { type_: UnsizedType.t + ; kind: + [ `Variable of varinfo + | `UserDeclared of Location_span.t + | `StanMath + | `UserDefined ] } +[@@deriving sexp] + +type t + +val create : unit -> t +(** Return a new type environment which contains the Stan math library functions +*) + +val find : t -> string -> info list + +val add : + t + -> string + -> Middle.UnsizedType.t + -> [ `UserDeclared of Location_span.t + | `StanMath + | `UserDefined + | `Variable of varinfo ] + -> t +(** Add a new item to the type environment. Does not overwrite existing, but shadows *) + +val set_raw : t -> string -> info list -> t +(** Overwrite the existing items bound to a name *) + +val mem : t -> string -> bool +val iter : t -> (info list -> unit) -> unit diff --git a/src/middle/Errors.ml b/src/frontend/Errors.ml similarity index 83% rename from src/middle/Errors.ml rename to src/frontend/Errors.ml index 03fa5ff827..a06e10e727 100644 --- a/src/middle/Errors.ml +++ b/src/frontend/Errors.ml @@ -5,9 +5,9 @@ module Str = Re.Str (** Our type of syntax error information *) type syntax_error = - | Lexing of string * Location.t - | Include of string * Location.t - | Parsing of string * Location_span.t + | Lexing of string * Middle.Location.t + | Include of string * Middle.Location.t + | Parsing of string * Middle.Location_span.t (** Exception for Syntax Errors *) exception SyntaxError of syntax_error @@ -31,13 +31,13 @@ let fatal_error ?(msg = "") _ = let pp_context_with_message ppf (msg, loc) = Fmt.pf ppf "%a@,%s" (Fmt.option Fmt.string) - (Location.context_to_string loc) + (Middle.Location.context_to_string loc) msg let pp_semantic_error ?printed_filename ppf err = let loc_span = Semantic_error.location err in Fmt.pf ppf "Semantic error in %s:@;%a" - (Location_span.to_string ?printed_filename loc_span) + (Middle.Location_span.to_string ?printed_filename loc_span) pp_context_with_message (Fmt.strf "%a@." Semantic_error.pp err, loc_span.begin_loc) @@ -45,18 +45,18 @@ let pp_semantic_error ?printed_filename ppf err = let pp_syntax_error ?printed_filename ppf = function | Parsing (message, loc_span) -> Fmt.pf ppf "Syntax error in %s, parsing error:@,%a" - (Location_span.to_string ?printed_filename loc_span) + (Middle.Location_span.to_string ?printed_filename loc_span) pp_context_with_message (message, loc_span.begin_loc) | Lexing (_, loc) -> Fmt.pf ppf "Syntax error in %s, lexing error:@,%a@." - (Location.to_string ?printed_filename + (Middle.Location.to_string ?printed_filename {loc with col_num= loc.col_num - 1}) pp_context_with_message ("Invalid character found.", loc) | Include (message, loc) -> Fmt.pf ppf "Syntax error in %s, include error:@,%a@." - (Location.to_string loc ?printed_filename) + (Middle.Location.to_string loc ?printed_filename) pp_context_with_message (message, loc) let pp ?printed_filename ppf = function diff --git a/src/middle/Errors.mli b/src/frontend/Errors.mli similarity index 89% rename from src/middle/Errors.mli rename to src/frontend/Errors.mli index 8b3cc78d91..1ecd70b79d 100644 --- a/src/middle/Errors.mli +++ b/src/frontend/Errors.mli @@ -2,9 +2,9 @@ (** Our type of syntax error information *) type syntax_error = - | Lexing of string * Location.t - | Include of string * Location.t - | Parsing of string * Location_span.t + | Lexing of string * Middle.Location.t + | Include of string * Middle.Location.t + | Parsing of string * Middle.Location_span.t (** Exception for Syntax Errors *) exception SyntaxError of syntax_error diff --git a/src/frontend/Frontend_utils.ml b/src/frontend/Frontend_utils.ml index 244dfa2838..565a101298 100644 --- a/src/frontend/Frontend_utils.ml +++ b/src/frontend/Frontend_utils.ml @@ -1,29 +1,18 @@ open Core_kernel -module Warnings = Middle.Warnings -module Errors = Middle.Errors -module Semantic_error = Middle.Semantic_error let untyped_ast_of_string s = let res, warnings = Parse.parse_string Parser.Incremental.program s in Fmt.epr "%a" (Fmt.list ~sep:Fmt.nop Warnings.pp) warnings ; res -let emit_warnings_and_return_ast (ast, warnings) = - Warnings.pp_warnings Fmt.stderr warnings ; - ast - let typed_ast_of_string_exn s = Result.( untyped_ast_of_string s >>= fun ast -> - Semantic_check.semantic_check_program ast - |> map_error ~f:(function - | err :: _ -> Errors.Semantic_error err - | [] -> - failwith - "Internal compiler error: no message from Semantic_check." )) + Typechecker.check_program ast + |> map_error ~f:(fun e -> Errors.Semantic_error e)) |> Result.map_error ~f:Errors.to_string - |> Result.ok_or_failwith |> emit_warnings_and_return_ast + |> Result.ok_or_failwith |> fst let get_ast_or_exit ?printed_filename ?(print_warnings = true) filename = let res, warnings = Parse.parse_file Parser.Incremental.program filename in @@ -34,17 +23,10 @@ let get_ast_or_exit ?printed_filename ?(print_warnings = true) filename = | Result.Error err -> Errors.pp Fmt.stderr err ; exit 1 let type_ast_or_exit ast = - try - match Semantic_check.semantic_check_program ast with - | Result.Ok p -> emit_warnings_and_return_ast p - | Result.Error (error :: _) -> - Errors.pp_semantic_error Fmt.stderr error ; - exit 1 - | Result.Error [] -> - Printf.eprintf - "Semantic check failed but reported no errors. This should never \ - happen." ; - exit 1 - with Errors.SemanticError err -> - Errors.pp_semantic_error Fmt.stderr err ; - exit 1 + match Typechecker.check_program ast with + | Result.Ok (p, warns) -> + Warnings.pp_warnings Fmt.stderr warns ; + p + | Result.Error error -> + Errors.pp_semantic_error Fmt.stderr error ; + exit 1 diff --git a/src/frontend/Info.mli b/src/frontend/Info.mli index 3c9898b72d..a21aa8f717 100644 --- a/src/frontend/Info.mli +++ b/src/frontend/Info.mli @@ -10,6 +10,7 @@ - [type]: the base type of the variable (["int"] or ["real"]). - [dimensions]: the number of dimensions ([0] for a scalar, [1] for a vector or row vector, etc.). + The JSON object also have the fields [stanlib_calls] and [distributions] containing the name of the standard library functions called and distributions used. diff --git a/src/frontend/Input_warnings.mli b/src/frontend/Input_warnings.mli index 51cb28fe05..e6fd106e65 100644 --- a/src/frontend/Input_warnings.mli +++ b/src/frontend/Input_warnings.mli @@ -2,7 +2,7 @@ val init : unit -> unit (** As something of a hack, Input_warnings keeps track of which warnings the lexer has emitted as a form of hidden state, which must be initialized and [collect]ed.*) -val collect : unit -> Middle.Warnings.t list +val collect : unit -> Warnings.t list (** Returns all of the warnings issued since [init] was called. *) val add_warning : Middle.Location_span.t -> string -> unit diff --git a/src/frontend/Parse.mli b/src/frontend/Parse.mli index 3033347878..61699275b7 100644 --- a/src/frontend/Parse.mli +++ b/src/frontend/Parse.mli @@ -5,13 +5,13 @@ open Core_kernel val parse_file : (Lexing.position -> Ast.untyped_program Parser.MenhirInterpreter.checkpoint) -> string - -> (Ast.untyped_program, Middle.Errors.t) result * Middle.Warnings.t list + -> (Ast.untyped_program, Errors.t) result * Warnings.t list (** A helper function to take a parser, a filename and produce an AST. Under the hood, it takes care of Menhir's custom syntax error messages. *) val parse_string : (Lexing.position -> Ast.untyped_program Parser.MenhirInterpreter.checkpoint) -> string - -> (Ast.untyped_program, Middle.Errors.t) result * Middle.Warnings.t list + -> (Ast.untyped_program, Errors.t) result * Warnings.t list (** A helper function to take a parser, a string and produce an AST. Under the hood, it takes care of Menhir's custom syntax error messages. *) diff --git a/src/frontend/Preprocessor.ml b/src/frontend/Preprocessor.ml index 6a0bbfcc20..3235a25379 100644 --- a/src/frontend/Preprocessor.ml +++ b/src/frontend/Preprocessor.ml @@ -17,7 +17,7 @@ let rec try_open_in paths fname pos = match paths with | [] -> raise - (Middle.Errors.SyntaxError + (Errors.SyntaxError (Include ( "Could not find include file " ^ fname ^ " in specified include paths." @@ -50,7 +50,7 @@ let try_get_new_lexbuf fname pos = new_lexbuf.lex_curr_p <- new_lexbuf.lex_start_p ; if dup_exists (Str.split (Str.regexp ", included from\n") path) then raise - (Middle.Errors.SyntaxError + (Errors.SyntaxError (Include ( Printf.sprintf "File %s recursively included itself." fname , Middle.Location.of_position_exn diff --git a/src/frontend/Pretty_printing.ml b/src/frontend/Pretty_printing.ml index 6bace6b092..343e2ba789 100644 --- a/src/frontend/Pretty_printing.ml +++ b/src/frontend/Pretty_printing.ml @@ -268,7 +268,7 @@ and pp_expression ppf ({expr= e_content; emeta= {loc; _}} : untyped_expression) ) | CondDistApp (_, id, es) -> ( match es with - | [] -> Middle.Errors.fatal_error () + | [] -> Errors.fatal_error () | e :: es' -> with_hbox ppf (fun () -> Fmt.pf ppf "%a(%a | %a)" pp_identifier id pp_expression e @@ -326,7 +326,7 @@ let pp_sizedtype ppf = function | SRowVector (_, e) -> Fmt.pf ppf "row_vector[%a]" pp_expression e | SMatrix (_, e1, e2) -> Fmt.pf ppf "matrix[%a, %a]" pp_expression e1 pp_expression e2 - | SArray _ -> raise (Middle.Errors.FatalError "This should never happen.") + | SArray _ -> raise (Errors.FatalError "This should never happen.") let pp_transformation ppf = function | Middle.Transformation.Identity -> Fmt.pf ppf "" @@ -593,7 +593,7 @@ let pp_program ppf pp_block_list ppf blocks let check_correctness prog pretty = - let result_ast, (_ : Middle.Warnings.t list) = + let result_ast, (_ : Warnings.t list) = Parse.parse_string Parser.Incremental.program pretty in if diff --git a/src/frontend/Semantic_check.ml b/src/frontend/Semantic_check.ml deleted file mode 100644 index f20268ae06..0000000000 --- a/src/frontend/Semantic_check.ml +++ /dev/null @@ -1,1947 +0,0 @@ -(** Semantic validation of AST*) - -(* Idea: check many of things related to identifiers that are hard to check - during parsing and are in fact irrelevant for building up the parse tree *) - -open Core_kernel -open Symbol_table -open Middle -open Ast -open Errors -module Validate = Common.Validation.Make (Semantic_error) - -let warnings = ref [] - -let add_warning (span : Location_span.t) (message : string) = - warnings := (span, message) :: !warnings - -let attach_warnings (x : typed_program) = (x, List.rev !warnings) - -(* There is a semantic checking function for each AST node that calls - the checking functions for its children left to right. *) - -(* Top level function semantic_check_program declares the AST while operating - on (1) a global symbol table vm, and (2) structure of type context_flags_record - to communicate information down the AST. *) - -let check_of_compatible_return_type rt1 srt2 = - UnsizedType.( - match (rt1, srt2) with - | Void, NoReturnType - |Void, Incomplete Void - |Void, Complete Void - |Void, AnyReturnType -> - true - | ReturnType UReal, Complete (ReturnType UInt) -> true - | ReturnType UComplex, Complete (ReturnType UReal) -> true - | ReturnType UComplex, Complete (ReturnType UInt) -> true - | ReturnType rt1, Complete (ReturnType rt2) -> rt1 = rt2 - | ReturnType _, AnyReturnType -> true - | _ -> false) - -(** Origin blocks, to keep track of where variables are declared *) -type originblock = - | MathLibrary - | Functions - | Data - | TData - | Param - | TParam - | Model - | GQuant - -(** Print all the signatures of a stan math operator, for the purposes of error messages. *) -let check_that_all_functions_have_definition = ref true - -let model_name = ref "" -let vm = Symbol_table.initialize () - -(* Record structure holding flags and other markers about context to be - used for error reporting. *) -type context_flags_record = - { current_block: originblock - ; in_toplevel_decl: bool - ; in_fun_def: bool - ; in_returning_fun_def: bool - ; in_rng_fun_def: bool - ; in_lp_fun_def: bool - ; in_udf_dist_def: bool - ; loop_depth: int } - -(* Some helper functions *) -let dup_exists l = - match List.find_a_dup ~compare:String.compare l with - | Some _ -> true - | None -> false - -let type_of_expr_typed ue = ue.emeta.type_ - -let calculate_autodifftype cf at ut = - match at with - | (Param | TParam | Model | Functions) - when not (UnsizedType.contains_int ut || cf.current_block = GQuant) -> - UnsizedType.AutoDiffable - | _ -> DataOnly - -let has_int_type ue = ue.emeta.type_ = UInt -let has_int_array_type ue = ue.emeta.type_ = UArray UInt - -let has_int_or_real_type ue = - match ue.emeta.type_ with UInt | UReal -> true | _ -> false - -let probability_distribution_name_variants id = - let name = id.name in - let open String in - List.map - ~f:(fun n -> {name= n; id_loc= id.id_loc}) - ( if name = "multiply_log" || name = "binomial_coefficient_log" then [name] - else if is_suffix ~suffix:"_lpmf" name then - [name; drop_suffix name 5 ^ "_lpdf"; drop_suffix name 5 ^ "_log"] - else if is_suffix ~suffix:"_lpdf" name then - [name; drop_suffix name 5 ^ "_lpmf"; drop_suffix name 5 ^ "_log"] - else if is_suffix ~suffix:"_lcdf" name then - [name; drop_suffix name 5 ^ "_cdf_log"] - else if is_suffix ~suffix:"_lccdf" name then - [name; drop_suffix name 6 ^ "_ccdf_log"] - else if is_suffix ~suffix:"_cdf_log" name then - [name; drop_suffix name 8 ^ "_lcdf"] - else if is_suffix ~suffix:"_ccdf_log" name then - [name; drop_suffix name 9 ^ "_lccdf"] - else if is_suffix ~suffix:"_log" name then - [name; drop_suffix name 4 ^ "_lpmf"; drop_suffix name 4 ^ "_lpdf"] - else [name] ) - -let lub_rt loc rt1 rt2 = - match (rt1, rt2) with - | UnsizedType.ReturnType UReal, UnsizedType.ReturnType UInt - |ReturnType UInt, ReturnType UReal -> - Validate.ok (UnsizedType.ReturnType UReal) - | _, _ when rt1 = rt2 -> Validate.ok rt2 - | _ -> Semantic_error.mismatched_return_types loc rt1 rt2 |> Validate.error - -(* -Checks that a variable/function name: - - if UDF that it does not match a Stan Math function - - a function/identifier does not have the _lupdf/_lupmf suffix - - is not already in use -*) -let check_fresh_variable_basic id is_udf = - Validate.( - if - is_udf - && ( Stan_math_signatures.is_stan_math_function_name id.name - (* variadic functions are currently not in math sigs *) - || Stan_math_signatures.is_reduce_sum_fn id.name - || Stan_math_signatures.is_variadic_ode_fn id.name ) - then Semantic_error.ident_is_stanmath_name id.id_loc id.name |> error - else if Utils.is_unnormalized_distribution id.name then - if is_udf then - Semantic_error.udf_is_unnormalized_fn id.id_loc id.name |> error - else - Semantic_error.ident_has_unnormalized_suffix id.id_loc id.name |> error - else - match Symbol_table.look vm id.name with - | Some _ -> Semantic_error.ident_in_use id.id_loc id.name |> error - | None -> ok ()) - -let check_fresh_variable id is_udf = - List.fold ~init:(Validate.ok ()) - ~f:(fun v0 name -> - check_fresh_variable_basic name is_udf |> Validate.apply_const v0 ) - (probability_distribution_name_variants id) - -(* == SEMANTIC CHECK OF PROGRAM ELEMENTS ==================================== *) - -(* Probably nothing to do here *) -let semantic_check_assignmentoperator op = Validate.ok op - -(* Probably nothing to do here *) -let semantic_check_autodifftype at = Validate.ok at - -(* Probably nothing to do here *) -let rec semantic_check_unsizedtype : UnsizedType.t -> unit Validate.t = - function - | UFun (l, rt, _, _) -> - (* fold over argument types accumulating errors with initial state - given by validating the return type *) - List.fold - ~f:(fun v0 (at, ut) -> - Validate.( - apply_const - (apply_const v0 (semantic_check_autodifftype at)) - (semantic_check_unsizedtype ut)) ) - ~init:(semantic_check_returntype rt) - l - | UArray ut -> semantic_check_unsizedtype ut - | _ -> Validate.ok () - -and semantic_check_returntype : UnsizedType.returntype -> unit Validate.t = - function - | Void -> Validate.ok () - | ReturnType ut -> semantic_check_unsizedtype ut - -(* -- Identifiers ---------------------------------------------------------- *) -let reserved_keywords = - [ "for"; "in"; "while"; "repeat"; "until"; "if"; "then"; "else"; "true" - ; "false"; "target"; "int"; "real"; "complex"; "void"; "vector"; "simplex" - ; "unit_vector"; "ordered"; "positive_ordered"; "row_vector"; "matrix" - ; "cholesky_factor_corr"; "cholesky_factor_cov"; "corr_matrix"; "cov_matrix" - ; "functions"; "model"; "data"; "parameters"; "quantities"; "transformed" - ; "generated"; "profile"; "return"; "break"; "continue"; "increment_log_prob" - ; "get_lp"; "print"; "reject"; "typedef"; "struct"; "var"; "export"; "extern" - ; "static"; "auto" ] - -let semantic_check_identifier id = - Validate.( - if id.name = !model_name then - Semantic_error.ident_is_model_name id.id_loc id.name |> error - else if - String.is_suffix id.name ~suffix:"__" - || List.exists ~f:(fun str -> str = id.name) reserved_keywords - then Semantic_error.ident_is_keyword id.id_loc id.name |> error - else ok ()) - -(* -- Operators ------------------------------------------------------------- *) -let semantic_check_operator _ = Validate.ok () - -(* == Expressions =========================================================== *) - -let arg_type x = (x.emeta.ad_level, x.emeta.type_) -let get_arg_types = List.map ~f:arg_type - -(* -- Function application -------------------------------------------------- *) - -let semantic_check_fn_conditioning ~loc id = - Validate.( - if - List.exists - ~f:(fun suffix -> String.is_suffix id.name ~suffix) - Utils.conditioning_suffices - && not (String.is_suffix id.name ~suffix:"_cdf") - then Semantic_error.conditioning_required loc |> error - else ok ()) - -(** `Target+=` can only be used in model and functions - with right suffix (same for tilde etc) -*) -let semantic_check_fn_target_plus_equals cf ~loc id = - Validate.( - if - String.is_suffix id.name ~suffix:"_lp" - && not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error - else ok ()) - -(** Rng functions cannot be used in Tp or Model and only - in function defs with the right suffix -*) -let semantic_check_fn_rng cf ~loc id = - Validate.( - if String.is_suffix id.name ~suffix:"_rng" && cf.in_toplevel_decl then - Semantic_error.invalid_decl_rng_fn loc |> error - else if - String.is_suffix id.name ~suffix:"_rng" - && ( (cf.in_fun_def && not cf.in_rng_fun_def) - || cf.current_block = TParam || cf.current_block = Model ) - then Semantic_error.invalid_rng_fn loc |> error - else ok ()) - -(** unnormalized _lpdf/_lpmf functions can only be used in _lpdf/_lpmf/_lp udfs - or the model block -*) -let semantic_check_unnormalized cf ~loc id = - Validate.( - if - Utils.is_unnormalized_distribution id.name - && not ((cf.in_fun_def && cf.in_udf_dist_def) || cf.current_block = Model) - then Semantic_error.invalid_unnormalized_fn loc |> error - else ok ()) - -let mk_fun_app ~is_cond_dist (x, y, z) = - if is_cond_dist then CondDistApp (x, y, z) else FunApp (x, y, z) - -(* Regular function application *) -let semantic_check_fn_normal ~is_cond_dist ~loc id es = - Validate.( - match Symbol_table.look vm (Utils.normalized_name id.name) with - | Some (_, UnsizedType.UFun (_, Void, _, _)) -> - Semantic_error.returning_fn_expected_nonreturning_found loc id.name - |> error - | Some (_, UFun (listedtypes, ReturnType ut, _, _)) -> ( - match - SignatureMismatch.check_compatible_arguments_mod_conv listedtypes - (get_arg_types es) - with - | Some x -> - es - |> List.map ~f:type_of_expr_typed - |> Semantic_error.illtyped_userdefined_fn_app loc id.name listedtypes - (ReturnType ut) x - |> error - | None -> - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (UserDefined (Fun_kind.suffix_from_name id.name), id, es)) - ~ad_level:(expr_ad_lub es) ~type_:ut ~loc - |> ok ) - | Some _ -> - (* Check that Funaps are actually functions *) - Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error - | None -> - Semantic_error.returning_fn_expected_undeclaredident_found loc id.name - |> error) - -(* Stan-Math function application *) -let semantic_check_fn_stan_math ~is_cond_dist ~loc id es = - match SignatureMismatch.stan_math_returntype id.name (get_arg_types es) with - | Ok Void -> - Semantic_error.returning_fn_expected_nonreturning_found loc id.name - |> Validate.error - | Ok (ReturnType ut) -> - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib (Fun_kind.suffix_from_name id.name), id, es)) - ~ad_level:(expr_ad_lub es) ~type_:ut ~loc - |> Validate.ok - | Error x -> - es - |> List.map ~f:(fun e -> e.emeta.type_) - |> Semantic_error.illtyped_stanlib_fn_app loc id.name x - |> Validate.error - -let semantic_check_reduce_sum ~is_cond_dist ~loc id es = - match es with - | { emeta= - { type_= - UnsizedType.UFun - (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _); _ - }; _ } - :: _ - when List.mem Stan_math_signatures.reduce_sum_slice_types - sliced_arg_fun_type ~equal:( = ) -> ( - let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in - let mandatory_fun_args = - [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] - in - match - SignatureMismatch.check_variadic_args true mandatory_args - mandatory_fun_args UReal (get_arg_types es) - with - | None -> - mk_typed_expression - ~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es)) - ~ad_level:(expr_ad_lub es) ~type_:UnsizedType.UReal ~loc - |> Validate.ok - | Some (expected_args, error) -> - Semantic_error.illtyped_reduce_sum loc id.name - (List.map ~f:type_of_expr_typed es) - expected_args error - |> Validate.error ) - | _ -> - let mandatory_args = - UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] - in - let mandatory_fun_args = - UnsizedType. - [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] - in - let expected_args, error = - SignatureMismatch.check_variadic_args true mandatory_args - mandatory_fun_args UReal (get_arg_types es) - |> Option.value_exn - in - Semantic_error.illtyped_reduce_sum_generic loc id.name - (List.map ~f:type_of_expr_typed es) - expected_args error - |> Validate.error - -let semantic_check_variadic_ode ~is_cond_dist ~loc id es = - let optional_tol_mandatory_args = - if Stan_math_signatures.variadic_ode_adjoint_fn = id.name then - Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types - else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn id.name then - Stan_math_signatures.variadic_ode_tol_arg_types - else [] - in - let mandatory_arg_types = - Stan_math_signatures.variadic_ode_mandatory_arg_types - @ optional_tol_mandatory_args - in - match - SignatureMismatch.check_variadic_args false mandatory_arg_types - Stan_math_signatures.variadic_ode_mandatory_fun_args - Stan_math_signatures.variadic_ode_fun_return_type (get_arg_types es) - with - | None -> - mk_typed_expression - ~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es)) - ~ad_level:(expr_ad_lub es) - ~type_:Stan_math_signatures.variadic_ode_return_type ~loc - |> Validate.ok - | Some (expected_args, err) -> - Semantic_error.illtyped_variadic_ode loc id.name - (List.map ~f:type_of_expr_typed es) - expected_args err - |> Validate.error - -let fn_kind_from_application id es = - (* We need to check an application here, rather than a mere name of the - function because, technically, user defined functions can shadow - constants in StanLib. *) - let suffix = Fun_kind.suffix_from_name id.name in - if - Stan_math_signatures.stan_math_returntype id.name - (List.map ~f:(fun x -> (x.emeta.ad_level, x.emeta.type_)) es) - <> None - || Symbol_table.look vm id.name = None - && Stan_math_signatures.is_stan_math_function_name id.name - then StanLib suffix - else UserDefined suffix - -(** Determines the function kind based on the identifier and performs the - corresponding semantic check -*) -let semantic_check_fn ~is_cond_dist ~loc id es = - match fn_kind_from_application id es with - | StanLib FnPlain when Stan_math_signatures.is_reduce_sum_fn id.name -> - semantic_check_reduce_sum ~is_cond_dist ~loc id es - | StanLib FnPlain when Stan_math_signatures.is_variadic_ode_fn id.name -> - semantic_check_variadic_ode ~is_cond_dist ~loc id es - | StanLib _ -> semantic_check_fn_stan_math ~is_cond_dist ~loc id es - | UserDefined _ -> semantic_check_fn_normal ~is_cond_dist ~loc id es - -(* -- Ternary If ------------------------------------------------------------ *) - -let semantic_check_ternary_if loc (pe, te, fe) = - Validate.( - let err = - Semantic_error.illtyped_ternary_if loc pe.emeta.type_ te.emeta.type_ - fe.emeta.type_ - in - if pe.emeta.type_ = UInt then - match UnsizedType.common_type (te.emeta.type_, fe.emeta.type_) with - | Some type_ when not (UnsizedType.is_fun_type type_) -> - mk_typed_expression - ~expr:(TernaryIf (pe, te, fe)) - ~ad_level:(expr_ad_lub [pe; te; fe]) - ~type_ ~loc - |> ok - | Some _ | None -> error err - else error err) - -(* -- Binary (Infix) Operators ---------------------------------------------- *) - -let semantic_check_binop loc op (le, re) = - Validate.( - let err = - Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_ - in - [le; re] |> List.map ~f:arg_type - |> Stan_math_signatures.operator_stan_math_return_type op - |> Option.value_map ~default:(error err) ~f:(function - | ReturnType type_ -> - mk_typed_expression - ~expr:(BinOp (le, op, re)) - ~ad_level:(expr_ad_lub [le; re]) - ~type_ ~loc - |> ok - | Void -> error err )) - -let to_exn v = - v |> Validate.to_result - |> Result.map_error ~f:Fmt.(to_to_string @@ list ~sep:cut Semantic_error.pp) - |> Result.ok_or_failwith - -let semantic_check_binop_exn loc op (le, re) = - semantic_check_binop loc op (le, re) |> to_exn - -(* -- Prefix Operators ------------------------------------------------------ *) - -let semantic_check_prefixop loc op e = - Validate.( - let err = Semantic_error.illtyped_prefix_op loc op e.emeta.type_ in - Stan_math_signatures.operator_stan_math_return_type op [arg_type e] - |> Option.value_map ~default:(error err) ~f:(function - | ReturnType type_ -> - mk_typed_expression - ~expr:(PrefixOp (op, e)) - ~ad_level:(expr_ad_lub [e]) - ~type_ ~loc - |> ok - | Void -> error err )) - -(* -- Postfix operators ----------------------------------------------------- *) - -let semantic_check_postfixop loc op e = - Validate.( - let err = Semantic_error.illtyped_postfix_op loc op e.emeta.type_ in - Stan_math_signatures.operator_stan_math_return_type op [arg_type e] - |> Option.value_map ~default:(error err) ~f:(function - | ReturnType type_ -> - mk_typed_expression - ~expr:(PostfixOp (e, op)) - ~ad_level:(expr_ad_lub [e]) - ~type_ ~loc - |> ok - | Void -> error err )) - -(* -- Variables ------------------------------------------------------------- *) -let semantic_check_variable cf loc id = - Validate.( - match Symbol_table.look vm (Utils.stdlib_distribution_name id.name) with - | None when not (Stan_math_signatures.is_stan_math_function_name id.name) - -> - Semantic_error.ident_not_in_scope loc id.name |> error - | None -> - mk_typed_expression ~expr:(Variable id) - ~ad_level: - (calculate_autodifftype cf MathLibrary UMathLibraryFunction) - ~type_:UMathLibraryFunction ~loc - |> ok - | Some ((Param | TParam | GQuant), _) when cf.in_toplevel_decl -> - Semantic_error.non_data_variable_size_decl loc |> error - | Some _ - when Utils.is_unnormalized_distribution id.name - && not - ( (cf.in_fun_def && (cf.in_udf_dist_def || cf.in_lp_fun_def)) - || cf.current_block = Model ) -> - Semantic_error.invalid_unnormalized_fn loc |> error - | Some (originblock, UFun (args, rt, FnLpdf _, mem_pattern)) -> - let type_ = - UnsizedType.UFun - (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) - in - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype cf originblock type_) - ~type_ ~loc - |> ok - | Some (originblock, type_) -> - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype cf originblock type_) - ~type_ ~loc - |> ok) - -(* -- Conditioned Distribution Application ---------------------------------- *) - -let semantic_check_conddist_name ~loc id = - Validate.( - if - List.exists - ~f:(fun x -> String.is_suffix id.name ~suffix:x) - Utils.conditioning_suffices - then ok () - else Semantic_error.conditional_notation_not_allowed loc |> error) - -(* -- Array Expressions ----------------------------------------------------- *) - -let check_consistent_types ad_level type_ es = - let f state e = - match state with - | Error e -> Error e - | Ok (ad, ty) -> ( - let ad = - if UnsizedType.autodifftype_can_convert e.emeta.ad_level ad then - e.emeta.ad_level - else ad - in - match UnsizedType.common_type (ty, e.emeta.type_) with - | Some ty -> Ok (ad, ty) - | None -> Error (ty, e.emeta) ) - in - List.fold ~init:(Ok (ad_level, type_)) ~f es - -let semantic_check_array_expr ~loc es = - Validate.( - match es with - | [] -> Semantic_error.empty_array loc |> error - | {emeta= {ad_level; type_; _}; _} :: elements -> ( - match check_consistent_types ad_level type_ elements with - | Error (ty, meta) -> - Semantic_error.mismatched_array_types meta.loc ty meta.type_ |> error - | Ok (ad_level, type_) -> - let type_ = UnsizedType.UArray type_ in - mk_typed_expression ~expr:(ArrayExpr es) ~ad_level ~type_ ~loc |> ok - )) - -(* -- Row Vector Expresssion ------------------------------------------------ *) - -let semantic_check_rowvector ~loc es = - Validate.( - match es with - | {emeta= {ad_level; type_= UnsizedType.URowVector; _}; _} :: elements -> ( - match check_consistent_types ad_level URowVector elements with - | Ok (ad_level, _) -> - mk_typed_expression ~expr:(RowVectorExpr es) ~ad_level ~type_:UMatrix - ~loc - |> ok - | Error (_, meta) -> - Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error ) - | _ -> ( - match check_consistent_types DataOnly UReal es with - | Ok (ad_level, _) -> - mk_typed_expression ~expr:(RowVectorExpr es) ~ad_level - ~type_:URowVector ~loc - |> ok - | Error (_, meta) -> - Semantic_error.invalid_row_vector_types meta.loc meta.type_ |> error - )) - -(* -- Indexed Expressions --------------------------------------------------- *) -let tuple2 a b = (a, b) -let tuple3 a b c = (a, b, c) - -let indexing_type idx = - match idx with - | Single {emeta= {type_= UnsizedType.UInt; _}; _} -> `Single - | _ -> `Multi - -let inferred_unsizedtype_of_indexed ~loc ut indices = - let rec aux type_ idcs = - match (type_, idcs) with - | _, [] -> Validate.ok type_ - | UnsizedType.UArray type_, `Single :: tl -> aux type_ tl - | UArray type_, `Multi :: tl -> - aux type_ tl |> Validate.map ~f:(fun t -> UnsizedType.UArray t) - | (UVector | URowVector), [`Single] | UMatrix, [`Single; `Single] -> - Validate.ok UnsizedType.UReal - | (UVector | URowVector | UMatrix), [`Multi] | UMatrix, [`Multi; `Multi] -> - Validate.ok type_ - | UMatrix, ([`Single] | [`Single; `Multi]) -> - Validate.ok UnsizedType.URowVector - | UMatrix, [`Multi; `Single] -> Validate.ok UnsizedType.UVector - | UMatrix, _ :: _ :: _ :: _ - |(UVector | URowVector), _ :: _ :: _ - |(UInt | UReal | UComplex | UFun _ | UMathLibraryFunction), _ :: _ -> - Semantic_error.not_indexable loc ut (List.length indices) - |> Validate.error - in - aux ut (List.map ~f:indexing_type indices) - -let inferred_unsizedtype_of_indexed_exn ~loc ut indices = - inferred_unsizedtype_of_indexed ~loc ut indices |> to_exn - -let inferred_ad_type_of_indexed at uindices = - UnsizedType.lub_ad_type - ( at - :: List.map - ~f:(function - | All -> UnsizedType.DataOnly - | Single ue1 | Upfrom ue1 | Downfrom ue1 -> - UnsizedType.lub_ad_type [at; ue1.emeta.ad_level] - | Between (ue1, ue2) -> - UnsizedType.lub_ad_type - [at; ue1.emeta.ad_level; ue2.emeta.ad_level]) - uindices ) - -let rec semantic_check_indexed ~loc ~cf e indices = - Validate.( - indices - |> List.map ~f:(semantic_check_index cf) - |> sequence - |> liftA2 tuple2 (semantic_check_expression cf e) - >>= fun (ue, uindices) -> - let at = inferred_ad_type_of_indexed ue.emeta.ad_level uindices in - uindices - |> inferred_unsizedtype_of_indexed ~loc ue.emeta.type_ - |> map ~f:(fun ut -> - mk_typed_expression - ~expr:(Indexed (ue, uindices)) - ~ad_level:at ~type_:ut ~loc )) - -and semantic_check_index cf = function - | All -> Validate.ok All - (* Check that indexes have int (container) type *) - | Single e -> - Validate.( - semantic_check_expression cf e - >>= fun ue -> - if has_int_type ue || has_int_array_type ue then ok @@ Single ue - else - Semantic_error.int_intarray_or_range_expected ue.emeta.loc - ue.emeta.type_ - |> error) - | Upfrom e -> - semantic_check_expression_of_int_type cf e "Range bound" - |> Validate.map ~f:(fun e -> Upfrom e) - | Downfrom e -> - semantic_check_expression_of_int_type cf e "Range bound" - |> Validate.map ~f:(fun e -> Downfrom e) - | Between (e1, e2) -> - let le = semantic_check_expression_of_int_type cf e1 "Range bound" - and ue = semantic_check_expression_of_int_type cf e2 "Range bound" in - Validate.liftA2 (fun l u -> Between (l, u)) le ue - -(* -- Top-level expressions ------------------------------------------------- *) -and semantic_check_expression cf ({emeta; expr} : Ast.untyped_expression) : - Ast.typed_expression Validate.t = - match expr with - | TernaryIf (e1, e2, e3) -> - let pe = semantic_check_expression cf e1 - and te = semantic_check_expression cf e2 - and fe = semantic_check_expression cf e3 in - Validate.(liftA3 tuple3 pe te fe >>= semantic_check_ternary_if emeta.loc) - | BinOp (e1, op, e2) -> - let le = semantic_check_expression cf e1 - and re = semantic_check_expression cf e2 - and warn_int_division (x, y) = - match (x.emeta.type_, y.emeta.type_, op) with - | UInt, UInt, Divide -> - let hint ppf () = - match (x.expr, y.expr) with - | IntNumeral x, _ -> - Fmt.pf ppf "%s.0 / %a" x Pretty_printing.pp_typed_expression - y - | _, Ast.IntNumeral y -> - Fmt.pf ppf "%a / %s.0" Pretty_printing.pp_typed_expression x - y - | _ -> - Fmt.pf ppf "%a * 1.0 / %a" - Pretty_printing.pp_typed_expression x - Pretty_printing.pp_typed_expression y - in - let s = - Fmt.strf - "@[@[Found int division:@]@ @[%a@]@,@[%a@]@ @[%a@]@,@[%a@]@]" - Pretty_printing.pp_expression {expr; emeta} Fmt.text - "Values will be rounded towards zero. If rounding is not \ - desired you can write the division as" - hint () Fmt.text - "If rounding is intended please use the integer division \ - operator %/%." - in - add_warning x.emeta.loc s ; (x, y) - | _ -> (x, y) - in - Validate.( - liftA2 tuple2 le re |> map ~f:warn_int_division - |> apply_const (semantic_check_operator op) - >>= semantic_check_binop emeta.loc op) - | PrefixOp (op, e) -> - Validate.( - semantic_check_expression cf e - |> apply_const (semantic_check_operator op) - >>= semantic_check_prefixop emeta.loc op) - | PostfixOp (e, op) -> - Validate.( - semantic_check_expression cf e - |> apply_const (semantic_check_operator op) - >>= semantic_check_postfixop emeta.loc op) - | Variable id -> - semantic_check_variable cf emeta.loc id - |> Validate.apply_const (semantic_check_identifier id) - | IntNumeral s -> ( - match float_of_string_opt s with - | Some i when i < 2_147_483_648.0 -> - mk_typed_expression ~expr:(IntNumeral s) ~ad_level:DataOnly ~type_:UInt - ~loc:emeta.loc - |> Validate.ok - | _ -> Semantic_error.bad_int_literal emeta.loc |> Validate.error ) - | RealNumeral s -> - mk_typed_expression ~expr:(RealNumeral s) ~ad_level:DataOnly ~type_:UReal - ~loc:emeta.loc - |> Validate.ok - | ImagNumeral s -> - mk_typed_expression ~expr:(ImagNumeral s) ~ad_level:DataOnly - ~type_:UComplex ~loc:emeta.loc - |> Validate.ok - | FunApp ((), id, es) -> - semantic_check_funapp ~is_cond_dist:false id es cf emeta - | CondDistApp ((), id, es) -> - semantic_check_funapp ~is_cond_dist:true id es cf emeta - | GetLP -> - (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) - if - not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then - Semantic_error.target_plusequals_outisde_model_or_logprob emeta.loc - |> Validate.error - else - mk_typed_expression ~expr:GetLP - ~ad_level:(calculate_autodifftype cf cf.current_block UReal) - ~type_:UReal ~loc:emeta.loc - |> Validate.ok - | GetTarget -> - (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) - if - not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then - Semantic_error.target_plusequals_outisde_model_or_logprob emeta.loc - |> Validate.error - else - mk_typed_expression ~expr:GetTarget - ~ad_level:(calculate_autodifftype cf cf.current_block UReal) - ~type_:UReal ~loc:emeta.loc - |> Validate.ok - | ArrayExpr es -> - Validate.( - es - |> List.map ~f:(semantic_check_expression cf) - |> sequence - >>= fun ues -> semantic_check_array_expr ~loc:emeta.loc ues) - | RowVectorExpr es -> - Validate.( - es - |> List.map ~f:(semantic_check_expression cf) - |> sequence - >>= semantic_check_rowvector ~loc:emeta.loc) - | Paren e -> - semantic_check_expression cf e - |> Validate.map ~f:(fun ue -> - mk_typed_expression ~expr:(Paren ue) ~ad_level:ue.emeta.ad_level - ~type_:ue.emeta.type_ ~loc:emeta.loc ) - | Indexed (e, indices) -> semantic_check_indexed ~loc:emeta.loc ~cf e indices - -and semantic_check_funapp ~is_cond_dist id es cf emeta = - let name_check = - if is_cond_dist then semantic_check_conddist_name - else semantic_check_fn_conditioning - in - Validate.( - es - |> List.map ~f:(semantic_check_expression cf) - |> sequence - >>= fun ues -> - semantic_check_fn ~is_cond_dist ~loc:emeta.loc id ues - |> apply_const (semantic_check_identifier id) - |> apply_const (name_check ~loc:emeta.loc id) - |> apply_const (semantic_check_fn_target_plus_equals cf ~loc:emeta.loc id) - |> apply_const (semantic_check_fn_rng cf ~loc:emeta.loc id) - |> apply_const (semantic_check_unnormalized cf ~loc:emeta.loc id)) - -and semantic_check_expression_of_int_type cf e name = - Validate.( - semantic_check_expression cf e - >>= fun ue -> - if has_int_type ue then ok ue - else Semantic_error.int_expected ue.emeta.loc name ue.emeta.type_ |> error) - -and semantic_check_expression_of_int_or_real_type cf e name = - Validate.( - semantic_check_expression cf e - >>= fun ue -> - if has_int_or_real_type ue then ok ue - else - Semantic_error.int_or_real_expected ue.emeta.loc name ue.emeta.type_ - |> error) - -let semantic_check_expression_of_scalar_or_type cf t e name = - Validate.( - semantic_check_expression cf e - >>= fun ue -> - if UnsizedType.is_scalar_type ue.emeta.type_ || ue.emeta.type_ = t then - ok ue - else - Semantic_error.scalar_or_type_expected ue.emeta.loc name t ue.emeta.type_ - |> error) - -(* -- Sized Types ----------------------------------------------------------- *) -let rec semantic_check_sizedtype cf = function - | SizedType.SInt -> Validate.ok SizedType.SInt - | SReal -> Validate.ok SizedType.SReal - | SComplex -> Validate.ok SizedType.SComplex - | SVector (mem_pattern, e) -> - semantic_check_expression_of_int_type cf e "Vector sizes" - |> Validate.map ~f:(fun ue -> SizedType.SVector (mem_pattern, ue)) - | SRowVector (mem_pattern, e) -> - semantic_check_expression_of_int_type cf e "Row vector sizes" - |> Validate.map ~f:(fun ue -> SizedType.SRowVector (mem_pattern, ue)) - | SMatrix (mem_pattern, e1, e2) -> - let ue1 = semantic_check_expression_of_int_type cf e1 "Matrix sizes" - and ue2 = semantic_check_expression_of_int_type cf e2 "Matrix sizes" in - Validate.liftA2 - (fun ue1 ue2 -> SizedType.SMatrix (mem_pattern, ue1, ue2)) - ue1 ue2 - | SArray (st, e) -> - let ust = semantic_check_sizedtype cf st - and ue = semantic_check_expression_of_int_type cf e "Array sizes" in - Validate.liftA2 (fun ust ue -> SizedType.SArray (ust, ue)) ust ue - -(* -- Transformations ------------------------------------------------------- *) -let semantic_check_transformation cf ut = function - | Transformation.Identity -> Validate.ok Transformation.Identity - | Lower e -> - semantic_check_expression_of_scalar_or_type cf ut e "Lower bound" - |> Validate.map ~f:(fun ue -> Transformation.Lower ue) - | Upper e -> - semantic_check_expression_of_scalar_or_type cf ut e "Upper bound" - |> Validate.map ~f:(fun ue -> Transformation.Upper ue) - | LowerUpper (e1, e2) -> - let ue1 = - semantic_check_expression_of_scalar_or_type cf ut e1 "Lower bound" - and ue2 = - semantic_check_expression_of_scalar_or_type cf ut e2 "Upper bound" - in - Validate.liftA2 - (fun ue1 ue2 -> Transformation.LowerUpper (ue1, ue2)) - ue1 ue2 - | Offset e -> - semantic_check_expression_of_scalar_or_type cf ut e "Offset" - |> Validate.map ~f:(fun ue -> Transformation.Offset ue) - | Multiplier e -> - semantic_check_expression_of_scalar_or_type cf ut e "Multiplier" - |> Validate.map ~f:(fun ue -> Transformation.Multiplier ue) - | OffsetMultiplier (e1, e2) -> - let ue1 = semantic_check_expression_of_scalar_or_type cf ut e1 "Offset" - and ue2 = - semantic_check_expression_of_scalar_or_type cf ut e2 "Multiplier" - in - Validate.liftA2 - (fun ue1 ue2 -> Transformation.OffsetMultiplier (ue1, ue2)) - ue1 ue2 - | Ordered -> Validate.ok Transformation.Ordered - | PositiveOrdered -> Validate.ok Transformation.PositiveOrdered - | Simplex -> Validate.ok Transformation.Simplex - | UnitVector -> Validate.ok Transformation.UnitVector - | CholeskyCorr -> Validate.ok Transformation.CholeskyCorr - | CholeskyCov -> Validate.ok Transformation.CholeskyCov - | Correlation -> Validate.ok Transformation.Correlation - | Covariance -> Validate.ok Transformation.Covariance - -(* -- Printables ------------------------------------------------------------ *) - -let semantic_check_printable cf = function - | PString s -> Validate.ok @@ PString s - (* Print/reject expressions cannot be of function type. *) - | PExpr e -> ( - Validate.( - semantic_check_expression cf e - >>= fun ue -> - match ue.emeta.type_ with - | UFun _ | UMathLibraryFunction -> - Semantic_error.not_printable ue.emeta.loc |> error - | _ -> ok @@ PExpr ue) ) - -(* -- Truncations ----------------------------------------------------------- *) - -let semantic_check_truncation cf = function - | NoTruncate -> Validate.ok NoTruncate - | TruncateUpFrom e -> - semantic_check_expression_of_int_or_real_type cf e "Truncation bound" - |> Validate.map ~f:(fun ue -> TruncateUpFrom ue) - | TruncateDownFrom e -> - semantic_check_expression_of_int_or_real_type cf e "Truncation bound" - |> Validate.map ~f:(fun ue -> TruncateDownFrom ue) - | TruncateBetween (e1, e2) -> - let ue1 = - semantic_check_expression_of_int_or_real_type cf e1 "Truncation bound" - and ue2 = - semantic_check_expression_of_int_or_real_type cf e2 "Truncation bound" - in - Validate.liftA2 (fun ue1 ue2 -> TruncateBetween (ue1, ue2)) ue1 ue2 - -(* == Statements ============================================================ *) - -(* -- Non-returning function application ------------------------------------ *) - -let semantic_check_nrfn_target ~loc ~cf id = - Validate.( - if - String.is_suffix id.name ~suffix:"_lp" - && not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error - else ok ()) - -let semantic_check_nrfn_normal ~loc id es = - Validate.( - match Symbol_table.look vm id.name with - | Some (_, UFun (listedtypes, Void, suffix, _)) -> ( - match - SignatureMismatch.check_compatible_arguments_mod_conv listedtypes - (get_arg_types es) - with - | None -> - mk_typed_statement - ~stmt:(NRFunApp (UserDefined suffix, id, es)) - ~return_type:NoReturnType ~loc - |> ok - | Some x -> - es - |> List.map ~f:type_of_expr_typed - |> Semantic_error.illtyped_userdefined_fn_app loc id.name listedtypes - Void x - |> error ) - | Some (_, UFun (_, ReturnType _, _, _)) -> - Semantic_error.nonreturning_fn_expected_returning_found loc id.name - |> error - | Some _ -> - Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name - |> error - | None -> - Semantic_error.nonreturning_fn_expected_undeclaredident_found loc - id.name - |> error) - -let semantic_check_nrfn_stan_math ~loc id es = - Validate.( - match - SignatureMismatch.stan_math_returntype id.name (get_arg_types es) - with - | Ok Void -> - mk_typed_statement - ~stmt:(NRFunApp (StanLib FnPlain, id, es)) - ~return_type:NoReturnType ~loc - |> ok - | Ok (ReturnType _) -> - Semantic_error.nonreturning_fn_expected_returning_found loc id.name - |> error - | Error x -> - es - |> List.map ~f:type_of_expr_typed - |> Semantic_error.illtyped_stanlib_fn_app loc id.name x - |> error) - -let semantic_check_nr_fnkind ~loc id es = - match fn_kind_from_application id es with - | StanLib _ -> semantic_check_nrfn_stan_math ~loc id es - | UserDefined _ -> semantic_check_nrfn_normal ~loc id es - -let semantic_check_nr_fn_app ~loc ~cf id es = - Validate.( - es - |> List.map ~f:(semantic_check_expression cf) - |> sequence - |> apply_const (semantic_check_identifier id) - |> apply_const (semantic_check_nrfn_target ~loc ~cf id) - >>= semantic_check_nr_fnkind ~loc id) - -(* -- Assignment ------------------------------------------------------------ *) - -let semantic_check_assignment_read_only ~loc id = - Validate.( - if Symbol_table.get_read_only vm id.name then - Semantic_error.cannot_assign_to_read_only loc id.name |> error - else ok ()) - -(* Variables from previous blocks are read-only. - In particular, data and parameters never assigned to -*) -let semantic_check_assignment_global ~loc ~cf ~block id = - Validate.( - if (not (Symbol_table.is_global vm id.name)) || block = cf.current_block - then ok () - else Semantic_error.cannot_assign_to_global loc id.name |> error) - -let mk_assignment_from_indexed_expr assop lhs rhs = - Assignment - {assign_lhs= Ast.lvalue_of_expr lhs; assign_op= assop; assign_rhs= rhs} - -let semantic_check_assignment_operator ~loc assop lhs rhs = - let op = - match assop with - | Assign | ArrowAssign -> Operator.Equals - | OperatorAssign op -> op - in - Validate.( - let err = - Semantic_error.illtyped_assignment loc op lhs.emeta.type_ rhs.emeta.type_ - in - match assop with - | Assign | ArrowAssign -> - if - UnsizedType.check_of_same_type_mod_array_conv "" lhs.emeta.type_ - rhs.emeta.type_ - then - mk_typed_statement ~return_type:NoReturnType ~loc - ~stmt:(mk_assignment_from_indexed_expr assop lhs rhs) - |> ok - else error err - | OperatorAssign op -> - List.map ~f:arg_type [lhs; rhs] - |> Stan_math_signatures.assignmentoperator_stan_math_return_type op - |> Option.value_map ~default:(error err) ~f:(function - | ReturnType _ -> error err - | Void -> - mk_typed_statement ~return_type:NoReturnType ~loc - ~stmt:(mk_assignment_from_indexed_expr assop lhs rhs) - |> ok )) - -let semantic_check_assignment ~loc ~cf assign_lhs assign_op assign_rhs = - let assign_id = Ast.id_of_lvalue assign_lhs in - let lhs = expr_of_lvalue assign_lhs |> semantic_check_expression cf - and assop = semantic_check_assignmentoperator assign_op - and rhs = semantic_check_expression cf assign_rhs - and block = - Symbol_table.look vm assign_id.name - |> Option.map ~f:(fun (block, _) -> Validate.ok block) - |> Option.value - ~default: - ( if Stan_math_signatures.is_stan_math_function_name assign_id.name - then Validate.ok MathLibrary - else - Validate.error - @@ Semantic_error.ident_not_in_scope loc assign_id.name ) - in - Validate.( - liftA2 tuple2 (liftA3 tuple3 lhs assop rhs) block - >>= fun ((lhs, assop, rhs), block) -> - semantic_check_assignment_operator ~loc assop lhs rhs - |> apply_const (semantic_check_assignment_global ~loc ~cf ~block assign_id) - |> apply_const (semantic_check_assignment_read_only ~loc assign_id)) - -(* -- Target plus-equals / Increment log-prob ------------------------------- *) - -let semantic_check_target_pe_expr_type ~loc e = - match e.emeta.type_ with - | UFun _ | UMathLibraryFunction -> - Semantic_error.int_or_real_container_expected loc e.emeta.type_ - |> Validate.error - | _ -> Validate.ok () - -let semantic_check_target_pe_usage ~loc ~cf = - if cf.in_lp_fun_def || cf.current_block = Model then Validate.ok () - else - Semantic_error.target_plusequals_outisde_model_or_logprob loc - |> Validate.error - -let semantic_check_target_pe ~loc ~cf e = - Validate.( - semantic_check_expression cf e - |> apply_const (semantic_check_target_pe_usage ~loc ~cf) - >>= fun ue -> - semantic_check_target_pe_expr_type ~loc ue - |> map ~f:(fun _ -> - mk_typed_statement ~stmt:(TargetPE ue) ~return_type:NoReturnType - ~loc )) - -let semantic_check_incr_logprob ~loc ~cf e = - Validate.( - semantic_check_expression cf e - |> apply_const (semantic_check_target_pe_usage ~loc ~cf) - >>= fun ue -> - semantic_check_target_pe_expr_type ~loc ue - |> map ~f:(fun _ -> - mk_typed_statement ~stmt:(IncrementLogProb ue) - ~return_type:NoReturnType ~loc )) - -(* -- Tilde (Sampling notation) --------------------------------------------- *) - -let semantic_check_sampling_pdf_pmf id = - Validate.( - if - String.( - is_suffix id.name ~suffix:"_lpdf" - || is_suffix id.name ~suffix:"_lpmf" - || is_suffix id.name ~suffix:"_lupdf" - || is_suffix id.name ~suffix:"_lupmf") - then error @@ Semantic_error.invalid_sampling_pdf_or_pmf id.id_loc - else ok ()) - -let semantic_check_sampling_cdf_ccdf ~loc id = - Validate.( - if - String.( - is_suffix id.name ~suffix:"_cdf" || is_suffix id.name ~suffix:"_ccdf") - then error @@ Semantic_error.invalid_sampling_cdf_or_ccdf loc id.name - else ok ()) - -(* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) -let semantic_check_valid_sampling_pos ~loc ~cf = - Validate.( - if not (cf.in_lp_fun_def || cf.current_block = Model) then - error @@ Semantic_error.target_plusequals_outisde_model_or_logprob loc - else ok ()) - -let semantic_check_sampling_distribution ~loc id arguments = - let name = id.name - and argumenttypes = List.map ~f:arg_type arguments - and is_real_rt = function - | UnsizedType.ReturnType UReal -> true - | _ -> false - in - let is_name_w_suffix_sampling_dist suffix = - Stan_math_signatures.stan_math_returntype (name ^ suffix) argumenttypes - |> Option.value_map ~default:false ~f:is_real_rt - in - let is_sampling_dist_in_math = - List.exists ~f:is_name_w_suffix_sampling_dist - (Utils.distribution_suffices @ Utils.unnormalized_suffices) - && name <> "binomial_coefficient" - && name <> "multiply" - in - let is_name_w_suffix_udf_sampling_dist suffix = - match Symbol_table.look vm (name ^ suffix) with - | Some (Functions, UFun (listedtypes, ReturnType UReal, FnLpdf _, _)) -> - UnsizedType.check_compatible_arguments_mod_conv name listedtypes - argumenttypes - | _ -> false - in - let is_udf_sampling_dist = - List.exists ~f:is_name_w_suffix_udf_sampling_dist - (Utils.distribution_suffices @ Utils.unnormalized_suffices) - in - Validate.( - if is_sampling_dist_in_math || is_udf_sampling_dist then ok () - else error @@ Semantic_error.invalid_sampling_no_such_dist loc name) - -let cumulative_density_is_defined id arguments = - let name = id.name - and argumenttypes = List.map ~f:arg_type arguments - and is_real_rt = function - | UnsizedType.ReturnType UReal -> true - | _ -> false - in - let is_real_rt_for_suffix suffix = - Stan_math_signatures.stan_math_returntype (name ^ suffix) argumenttypes - |> Option.value_map ~default:false ~f:is_real_rt - and valid_arg_types_for_suffix suffix = - match Symbol_table.look vm (name ^ suffix) with - | Some (Functions, UFun (listedtypes, ReturnType UReal, FnPlain, _)) -> - UnsizedType.check_compatible_arguments_mod_conv name listedtypes - argumenttypes - | _ -> false - in - ( is_real_rt_for_suffix "_lcdf" - || valid_arg_types_for_suffix "_lcdf" - || is_real_rt_for_suffix "_cdf_log" - || valid_arg_types_for_suffix "_cdf_log" ) - && ( is_real_rt_for_suffix "_lccdf" - || valid_arg_types_for_suffix "_lccdf" - || is_real_rt_for_suffix "_ccdf_log" - || valid_arg_types_for_suffix "_ccdf_log" ) - -let can_truncate_distribution ~loc (arg : typed_expression) = function - | NoTruncate -> Validate.ok () - | _ -> - if UnsizedType.is_scalar_type arg.emeta.type_ then Validate.ok () - else Validate.error @@ Semantic_error.multivariate_truncation loc - -let semantic_check_sampling_cdf_defined ~loc id truncation args = - Validate.( - match truncation with - | NoTruncate -> ok () - | TruncateUpFrom e when cumulative_density_is_defined id (e :: args) -> - ok () - | TruncateDownFrom e when cumulative_density_is_defined id (e :: args) -> - ok () - | TruncateBetween (e1, e2) - when cumulative_density_is_defined id (e1 :: args) - && cumulative_density_is_defined id (e2 :: args) -> - ok () - | _ -> error @@ Semantic_error.invalid_truncation_cdf_or_ccdf loc) - -let semantic_check_tilde ~loc ~cf distribution truncation arg args = - Validate.( - let ue = semantic_check_expression cf arg - and ues = List.map ~f:(semantic_check_expression cf) args |> sequence - and ut = semantic_check_truncation cf truncation in - liftA3 tuple3 ut ue ues - |> apply_const (semantic_check_identifier distribution) - |> apply_const (semantic_check_sampling_pdf_pmf distribution) - |> apply_const (semantic_check_valid_sampling_pos ~loc ~cf) - |> apply_const (semantic_check_sampling_cdf_ccdf ~loc distribution) - >>= fun (truncation, arg, args) -> - semantic_check_sampling_distribution ~loc distribution (arg :: args) - |> apply_const - (semantic_check_sampling_cdf_defined ~loc distribution truncation args) - |> apply_const (can_truncate_distribution ~loc arg truncation) - |> map ~f:(fun _ -> - let stmt = Tilde {arg; distribution; args; truncation} in - mk_typed_statement ~stmt ~loc ~return_type:NoReturnType )) - -(* -- Break ----------------------------------------------------------------- *) -(* Break and continue only occur in loops. *) -let semantic_check_break ~loc ~cf = - Validate.( - if cf.loop_depth = 0 then Semantic_error.break_outside_loop loc |> error - else mk_typed_statement ~stmt:Break ~return_type:NoReturnType ~loc |> ok) - -(* -- Continue -------------------------------------------------------------- *) - -let semantic_check_continue ~loc ~cf = - Validate.( - (* Break and continue only occur in loops. *) - if cf.loop_depth = 0 then Semantic_error.continue_outside_loop loc |> error - else mk_typed_statement ~stmt:Continue ~return_type:NoReturnType ~loc |> ok) - -(* -- Return ---------------------------------------------------------------- *) - -(** No returns outside of function definitions - In case of void function, no return statements anywhere -*) -let semantic_check_return ~loc ~cf e = - Validate.( - if not cf.in_returning_fun_def then - Semantic_error.expression_return_outside_returning_fn loc |> error - else - semantic_check_expression cf e - |> map ~f:(fun ue -> - mk_typed_statement ~stmt:(Return ue) - ~return_type:(Complete (ReturnType ue.emeta.type_)) ~loc )) - -(* -- Return `void` --------------------------------------------------------- *) - -let semantic_check_returnvoid ~loc ~cf = - Validate.( - if (not cf.in_fun_def) || cf.in_returning_fun_def then - Semantic_error.void_ouside_nonreturning_fn loc |> error - else - mk_typed_statement ~stmt:ReturnVoid ~return_type:(Complete Void) ~loc - |> ok) - -(* -- Print ----------------------------------------------------------------- *) - -let semantic_check_print ~loc ~cf ps = - Validate.( - ps - |> List.map ~f:(semantic_check_printable cf) - |> sequence - |> map ~f:(fun ups -> - mk_typed_statement ~stmt:(Print ups) ~return_type:NoReturnType ~loc - )) - -(* -- Reject ---------------------------------------------------------------- *) - -let semantic_check_reject ~loc ~cf ps = - Validate.( - ps - |> List.map ~f:(semantic_check_printable cf) - |> sequence - |> map ~f:(fun ups -> - mk_typed_statement ~stmt:(Reject ups) ~return_type:AnyReturnType - ~loc )) - -(* -- Skip ------------------------------------------------------------------ *) - -let semantic_check_skip ~loc = - mk_typed_statement ~stmt:Skip ~return_type:NoReturnType ~loc |> Validate.ok - -(* -- If-Then-Else ---------------------------------------------------------- *) - -let try_compute_ifthenelse_statement_returntype loc srt1 srt2 = - match (srt1, srt2) with - | Complete rt1, Complete rt2 -> - lub_rt loc rt1 rt2 |> Validate.map ~f:(fun t -> Complete t) - | Incomplete rt1, Incomplete rt2 - |Complete rt1, Incomplete rt2 - |Incomplete rt1, Complete rt2 -> - lub_rt loc rt1 rt2 |> Validate.map ~f:(fun t -> Incomplete t) - | AnyReturnType, NoReturnType - |NoReturnType, AnyReturnType - |NoReturnType, NoReturnType -> - Validate.ok NoReturnType - | AnyReturnType, Incomplete rt - |Incomplete rt, AnyReturnType - |Complete rt, NoReturnType - |NoReturnType, Complete rt - |NoReturnType, Incomplete rt - |Incomplete rt, NoReturnType -> - Validate.ok @@ Incomplete rt - | Complete rt, AnyReturnType | AnyReturnType, Complete rt -> - Validate.ok @@ Complete rt - | AnyReturnType, AnyReturnType -> Validate.ok AnyReturnType - -let rec semantic_check_if_then_else ~loc ~cf pred_e s_true s_false_opt = - let us1 = semantic_check_statement cf s_true - and uos2 = - s_false_opt - |> Option.map ~f:(fun s -> - semantic_check_statement cf s |> Validate.map ~f:Option.some ) - |> Option.value ~default:(Validate.ok None) - and ue = - semantic_check_expression_of_int_or_real_type cf pred_e - "Condition in conditional" - in - Validate.( - liftA3 tuple3 ue us1 uos2 - >>= fun (ue, us1, uos2) -> - let stmt = IfThenElse (ue, us1, uos2) - and srt1 = us1.smeta.return_type - and srt2 = - uos2 - |> Option.map ~f:(fun s -> s.smeta.return_type) - |> Option.value ~default:NoReturnType - in - try_compute_ifthenelse_statement_returntype loc srt1 srt2 - |> map ~f:(fun return_type -> mk_typed_statement ~stmt ~return_type ~loc)) - -(* -- While Statements ------------------------------------------------------ *) -and semantic_check_while ~loc ~cf e s = - let us = semantic_check_statement {cf with loop_depth= cf.loop_depth + 1} s - and ue = - semantic_check_expression_of_int_or_real_type cf e - "Condition in while-loop" - in - Validate.liftA2 - (fun ue us -> - mk_typed_statement - ~stmt:(While (ue, us)) - ~return_type:us.smeta.return_type ~loc ) - ue us - -(* -- For Statements -------------------------------------------------------- *) -and semantic_check_loop_body ~cf loop_var loop_var_ty loop_body = - Symbol_table.begin_scope vm ; - let is_fresh_var = check_fresh_variable loop_var false in - Symbol_table.enter vm loop_var.name (cf.current_block, loop_var_ty) ; - (* Check that function args and loop identifiers are not modified in - function. (passed by const ref) *) - Symbol_table.set_read_only vm loop_var.name ; - let us = - semantic_check_statement {cf with loop_depth= cf.loop_depth + 1} loop_body - |> Validate.apply_const is_fresh_var - in - Symbol_table.end_scope vm ; us - -and semantic_check_for ~loc ~cf loop_var lower_bound_e upper_bound_e loop_body - = - let ue1 = - semantic_check_expression_of_int_type cf lower_bound_e - "Lower bound of for-loop" - and ue2 = - semantic_check_expression_of_int_type cf upper_bound_e - "Upper bound of for-loop" - in - Validate.( - liftA2 tuple2 ue1 ue2 - |> apply_const (semantic_check_identifier loop_var) - >>= fun (ue1, ue2) -> - semantic_check_loop_body ~cf loop_var UInt loop_body - |> map ~f:(fun us -> - mk_typed_statement - ~stmt: - (For - { loop_variable= loop_var - ; lower_bound= ue1 - ; upper_bound= ue2 - ; loop_body= us }) - ~return_type:us.smeta.return_type ~loc )) - -(* -- Foreach Statements ---------------------------------------------------- *) -and semantic_check_foreach_loop_identifier_type ~loc ty = - Validate.( - match ty with - | UnsizedType.UArray ut -> ok ut - | UVector | URowVector | UMatrix -> ok UnsizedType.UReal - | _ -> - Semantic_error.array_vector_rowvector_matrix_expected loc ty |> error) - -and semantic_check_foreach ~loc ~cf loop_var foreach_expr loop_body = - Validate.( - semantic_check_expression cf foreach_expr - |> apply_const (semantic_check_identifier loop_var) - >>= fun ue -> - semantic_check_foreach_loop_identifier_type ~loc:ue.emeta.loc - ue.emeta.type_ - >>= fun loop_var_ty -> - semantic_check_loop_body ~cf loop_var loop_var_ty loop_body - |> map ~f:(fun us -> - mk_typed_statement - ~stmt:(ForEach (loop_var, ue, us)) - ~return_type:us.smeta.return_type ~loc )) - -(* -- Blocks ---------------------------------------------------------------- *) -and stmt_is_escape {stmt; _} = - match stmt with - | Break | Continue | Reject _ | Return _ | ReturnVoid -> true - | _ -> false - -and list_until_escape xs = - let rec aux accu = function - | next :: next' :: _ when stmt_is_escape next' -> - List.rev (next' :: next :: accu) - | next :: rest -> aux (next :: accu) rest - | [] -> List.rev accu - in - aux [] xs - -and try_compute_block_statement_returntype loc srt1 srt2 = - match (srt1, srt2) with - | Complete rt1, Complete rt2 | Incomplete rt1, Complete rt2 -> - lub_rt loc rt1 rt2 |> Validate.map ~f:(fun t -> Complete t) - | Incomplete rt1, Incomplete rt2 | Complete rt1, Incomplete rt2 -> - lub_rt loc rt1 rt2 |> Validate.map ~f:(fun t -> Incomplete t) - | NoReturnType, NoReturnType -> Validate.ok NoReturnType - | AnyReturnType, Incomplete rt - |Complete rt, NoReturnType - |NoReturnType, Incomplete rt - |Incomplete rt, NoReturnType -> - Validate.ok @@ Incomplete rt - | NoReturnType, Complete rt - |Complete rt, AnyReturnType - |Incomplete rt, AnyReturnType - |AnyReturnType, Complete rt -> - Validate.ok @@ Complete rt - | AnyReturnType, NoReturnType - |NoReturnType, AnyReturnType - |AnyReturnType, AnyReturnType -> - Validate.ok AnyReturnType - -and semantic_check_block ~loc ~cf stmts = - Symbol_table.begin_scope vm ; - (* Any statements after a break or continue or return or reject - do not count for the return type. - *) - let validated_stmts = - List.map ~f:(semantic_check_statement cf) stmts |> Validate.sequence - in - Symbol_table.end_scope vm ; - Validate.( - validated_stmts - >>= fun xs -> - let return_ty = - xs |> list_until_escape - |> List.map ~f:(fun s -> s.smeta.return_type) - |> List.fold ~init:(ok NoReturnType) ~f:(fun accu x -> - accu >>= fun y -> try_compute_block_statement_returntype loc y x - ) - in - map return_ty ~f:(fun return_type -> - mk_typed_statement ~stmt:(Block xs) ~return_type ~loc )) - -and semantic_check_profile ~loc ~cf name stmts = - Symbol_table.begin_scope vm ; - (* Any statements after a break or continue or return or reject - do not count for the return type. - *) - let validated_stmts = - List.map ~f:(semantic_check_statement cf) stmts |> Validate.sequence - in - Symbol_table.end_scope vm ; - Validate.( - validated_stmts - >>= fun xs -> - let return_ty = - xs |> list_until_escape - |> List.map ~f:(fun s -> s.smeta.return_type) - |> List.fold ~init:(ok NoReturnType) ~f:(fun accu x -> - accu >>= fun y -> try_compute_block_statement_returntype loc y x - ) - in - map return_ty ~f:(fun return_type -> - mk_typed_statement ~stmt:(Profile (name, xs)) ~return_type ~loc )) - -(* -- Variable Declarations ------------------------------------------------- *) -and semantic_check_var_decl_bounds ~loc is_global sized_ty trans = - let is_real {emeta; _} = emeta.type_ = UReal in - let is_real_transformation = - match trans with - | Transformation.Lower e -> is_real e - | Upper e -> is_real e - | LowerUpper (e1, e2) -> is_real e1 || is_real e2 - | _ -> false - in - let is_transformation = - match trans with Transformation.Identity -> false | _ -> true - in - Validate.( - if is_global && sized_ty = SizedType.SInt && is_real_transformation then - Semantic_error.non_int_bounds loc |> error - else if - is_global - && SizedType.(inner_type sized_ty = SComplex) - && is_transformation - then Semantic_error.complex_transform loc |> error - else ok ()) - -and semantic_check_transformed_param_ty ~loc ~cf is_global unsized_ty = - Validate.( - if - is_global - && (cf.current_block = Param || cf.current_block = TParam) - && UnsizedType.contains_int unsized_ty - then Semantic_error.transformed_params_int loc |> error - else ok ()) - -and semantic_check_var_decl_initial_value ~loc ~cf id init_val_opt = - init_val_opt - |> Option.value_map ~default:(Validate.ok None) ~f:(fun e -> - let stmt = - Assignment - { assign_lhs= {lval= LVariable id; lmeta= {loc}} - ; assign_op= Assign - ; assign_rhs= e } - in - mk_untyped_statement ~loc ~stmt - |> semantic_check_statement cf - |> Validate.map ~f:(fun ts -> - match (ts.stmt, ts.smeta.return_type) with - | Assignment {assign_rhs= ue; _}, NoReturnType -> Some ue - | _ -> - let msg = - "semantic_check_var_decl: `Assignment` expected." - in - fatal_error ~msg () ) ) - -and semantic_check_var_decl ~loc ~cf sized_ty trans id init is_global = - let checked_stmt = - semantic_check_sizedtype {cf with in_toplevel_decl= is_global} sized_ty - in - Validate.( - let checked_trans = - checked_stmt - >>= fun ust -> - semantic_check_transformation cf (SizedType.to_unsized ust) trans - in - liftA2 tuple2 checked_stmt checked_trans - |> apply_const (semantic_check_identifier id) - |> apply_const (check_fresh_variable id false) - >>= fun (ust, utrans) -> - let ut = SizedType.to_unsized ust in - Symbol_table.enter vm id.name (cf.current_block, ut) ; - semantic_check_var_decl_initial_value ~loc ~cf id init - |> apply_const (semantic_check_var_decl_bounds ~loc is_global ust utrans) - |> apply_const (semantic_check_transformed_param_ty ~loc ~cf is_global ut) - |> map ~f:(fun uinit -> - let stmt = - VarDecl - { decl_type= Sized ust - ; transformation= utrans - ; identifier= id - ; initial_value= uinit - ; is_global } - in - mk_typed_statement ~stmt ~loc ~return_type:NoReturnType )) - -(* -- Function definitions -------------------------------------------------- *) -and semantic_check_fundef_overloaded ~loc id arg_tys rt = - Validate.( - (* User defined functions cannot be overloaded *) - if Symbol_table.check_is_unassigned vm id.name then - match Symbol_table.look vm id.name with - | Some (Functions, UFun (arg_tys', rt', _, _)) - when arg_tys' = arg_tys && rt' = rt -> - ok () - | _ -> - Symbol_table.look vm id.name - |> Option.map ~f:snd - |> Semantic_error.mismatched_fn_def_decl loc id.name - |> error - else check_fresh_variable id true) - -(** WARNING: side effecting *) -and semantic_check_fundef_decl ~loc id body = - Validate.( - match body with - | {stmt= Skip; _} -> - if Symbol_table.check_is_unassigned vm id.name then - error @@ Semantic_error.fn_decl_without_def loc - else - let () = Symbol_table.set_is_unassigned vm id.name in - ok () - | _ -> - Symbol_table.set_is_assigned vm id.name ; - ok ()) - -and semantic_check_fundef_dist_rt ~loc id return_ty = - Validate.( - let is_dist = - List.exists - ~f:(fun x -> String.is_suffix id.name ~suffix:x) - Utils.conditioning_suffices_w_log - in - if is_dist then - match return_ty with - | UnsizedType.ReturnType UReal -> ok () - | _ -> error @@ Semantic_error.non_real_prob_fn_def loc - else ok ()) - -and semantic_check_pdf_fundef_first_arg_ty ~loc id arg_tys = - Validate.( - (* TODO: I think these kind of functions belong with the type definition *) - let is_real_type = function - | UnsizedType.UReal | UVector | URowVector | UMatrix - |UArray UReal - |UArray UVector - |UArray URowVector - |UArray UMatrix -> - true - | _ -> false - in - if String.is_suffix id.name ~suffix:"_lpdf" then - List.hd arg_tys - |> Option.value_map - ~default: - (error @@ Semantic_error.prob_density_non_real_variate loc None) - ~f:(fun (_, rt) -> - if is_real_type rt then ok () - else - error - @@ Semantic_error.prob_density_non_real_variate loc (Some rt) ) - else ok ()) - -and semantic_check_pmf_fundef_first_arg_ty ~loc id arg_tys = - Validate.( - (* TODO: I think these kind of functions belong with the type definition *) - let is_int_type = function - | UnsizedType.UInt | UArray UInt -> true - | _ -> false - in - if String.is_suffix id.name ~suffix:"_lpmf" then - List.hd arg_tys - |> Option.value_map - ~default:(error @@ Semantic_error.prob_mass_non_int_variate loc None) - ~f:(fun (_, rt) -> - if is_int_type rt then ok () - else - error @@ Semantic_error.prob_mass_non_int_variate loc (Some rt) - ) - else ok ()) - -(* All function arguments are distinct *) -and semantic_check_fundef_distinct_arg_ids ~loc arg_names = - Validate.( - if dup_exists arg_names then - error @@ Semantic_error.duplicate_arg_names loc - else ok ()) - -(* Check that every trace through function body contains return statement of right type *) -and semantic_check_fundef_return_tys ~loc id return_type body = - Validate.( - if - Symbol_table.check_is_unassigned vm id.name - || check_of_compatible_return_type return_type body.smeta.return_type - then ok () - else error @@ Semantic_error.incompatible_return_types loc) - -and semantic_check_fundef ~loc ~cf return_ty id args body = - let uargs = - List.map args ~f:(fun (at, ut, id) -> - Validate.( - semantic_check_autodifftype at - |> apply_const (semantic_check_unsizedtype ut) - |> apply_const (semantic_check_identifier id) - |> map ~f:(fun at -> (at, ut, id))) ) - |> Validate.sequence - in - Validate.( - uargs - |> apply_const (semantic_check_identifier id) - |> apply_const (semantic_check_returntype return_ty) - >>= fun uargs -> - let urt = return_ty in - let uarg_types = List.map ~f:(fun (w, y, _) -> (w, y)) uargs in - let uarg_identifiers = List.map ~f:(fun (_, _, z) -> z) uargs in - let uarg_names = List.map ~f:(fun x -> x.name) uarg_identifiers in - semantic_check_fundef_overloaded ~loc id uarg_types urt - |> apply_const (semantic_check_fundef_decl ~loc id body) - >>= fun _ -> - (* WARNING: SIDE EFFECTING *) - Symbol_table.enter vm id.name - ( Functions - , UFun (uarg_types, urt, Fun_kind.suffix_from_name id.name, AoS) ) ; - (* Check that function args and loop identifiers are not modified in - function. (passed by const ref)*) - List.iter ~f:(Symbol_table.set_read_only vm) uarg_names ; - semantic_check_fundef_dist_rt ~loc id urt - |> apply_const (semantic_check_pdf_fundef_first_arg_ty ~loc id uarg_types) - |> apply_const (semantic_check_pmf_fundef_first_arg_ty ~loc id uarg_types) - >>= fun _ -> - (* WARNING: SIDE EFFECTING *) - Symbol_table.begin_scope vm ; - List.map ~f:(fun x -> check_fresh_variable x false) uarg_identifiers - |> sequence - |> apply_const (semantic_check_fundef_distinct_arg_ids ~loc uarg_names) - >>= fun _ -> - (* TODO: Bob was suggesting that function arguments must be allowed to - shadow user defined functions but not library functions. - Should we allow for that? - *) - (* We treat DataOnly arguments as if they are data and AutoDiffable arguments - as if they are parameters, for the purposes of type checking. - *) - (* WARNING: SIDE EFFECTING *) - let _ : unit Base.List.Or_unequal_lengths.t = - List.iter2 ~f:(Symbol_table.enter vm) uarg_names - (List.map - ~f:(function - | UnsizedType.DataOnly, ut -> (Data, ut) - | AutoDiffable, ut -> (Param, ut)) - uarg_types) - and context = - let is_udf_dist name = - List.exists - ~f:(fun suffix -> String.is_suffix name ~suffix) - Utils.distribution_suffices - in - { cf with - in_fun_def= true - ; in_rng_fun_def= String.is_suffix id.name ~suffix:"_rng" - ; in_lp_fun_def= String.is_suffix id.name ~suffix:"_lp" - ; in_udf_dist_def= is_udf_dist id.name - ; in_returning_fun_def= urt <> Void } - in - let body' = semantic_check_statement context body in - body' - >>= fun ub -> - semantic_check_fundef_return_tys ~loc id urt ub - |> map ~f:(fun _ -> - (* WARNING: SIDE EFFECTING *) - Symbol_table.end_scope vm ; - let stmt = - FunDef {returntype= urt; funname= id; arguments= uargs; body= ub} - in - mk_typed_statement ~return_type:NoReturnType ~loc ~stmt )) - -(* -- Top-level Statements -------------------------------------------------- *) -and semantic_check_statement cf (s : Ast.untyped_statement) : - Ast.typed_statement Validate.t = - let loc = s.smeta.loc in - match s.stmt with - | NRFunApp (_, id, es) -> semantic_check_nr_fn_app ~loc ~cf id es - | Assignment {assign_lhs; assign_op; assign_rhs} -> - semantic_check_assignment ~loc ~cf assign_lhs assign_op assign_rhs - | TargetPE e -> semantic_check_target_pe ~loc ~cf e - | IncrementLogProb e -> semantic_check_incr_logprob ~loc ~cf e - | Tilde {arg; distribution; args; truncation} -> - semantic_check_tilde ~loc ~cf distribution truncation arg args - | Break -> semantic_check_break ~loc ~cf - | Continue -> semantic_check_continue ~loc ~cf - | Return e -> semantic_check_return ~loc ~cf e - | ReturnVoid -> semantic_check_returnvoid ~loc ~cf - | Print ps -> semantic_check_print ~loc ~cf ps - | Reject ps -> semantic_check_reject ~loc ~cf ps - | Skip -> semantic_check_skip ~loc - | IfThenElse (e, s1, os2) -> semantic_check_if_then_else ~loc ~cf e s1 os2 - | While (e, s) -> semantic_check_while ~loc ~cf e s - | For {loop_variable; lower_bound; upper_bound; loop_body} -> - semantic_check_for ~loc ~cf loop_variable lower_bound upper_bound - loop_body - | ForEach (id, e, s) -> semantic_check_foreach ~loc ~cf id e s - | Block vdsl -> semantic_check_block ~loc ~cf vdsl - | Profile (name, vdsl) -> semantic_check_profile ~loc ~cf name vdsl - | VarDecl {decl_type= Unsized _; _} -> - raise_s [%message "Don't support unsized declarations yet."] - | VarDecl - { decl_type= Sized st - ; transformation - ; identifier - ; initial_value - ; is_global } -> - semantic_check_var_decl ~loc ~cf st transformation identifier - initial_value is_global - | FunDef {returntype; funname; arguments; body} -> - semantic_check_fundef ~loc ~cf returntype funname arguments body - -(* == Untyped programs ====================================================== *) - -let semantic_check_ostatements_in_block ~cf block stmts_opt = - let cf' = {cf with current_block= block} in - Option.value_map stmts_opt ~default:(Validate.ok None) - ~f:(fun {stmts; xloc} -> - (* I'm folding since I'm not sure if map is guaranteed to - respect the ordering of the list *) - List.fold ~init:[] stmts ~f:(fun accu stmt -> - let s = semantic_check_statement cf' stmt in - s :: accu ) - |> List.rev |> Validate.sequence - |> Validate.map ~f:(fun stmts -> Some {stmts; xloc}) ) - -let check_fun_def_body_in_block = function - | {stmt= FunDef {body= {stmt= Block _; _}; _}; _} - |{stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> - Validate.ok () - | {stmt= FunDef {body= {stmt= _; smeta}; _}; _} -> - Validate.error @@ Semantic_error.fn_decl_needs_block smeta.loc - | _ -> Validate.ok () - -let semantic_check_functions_have_defn function_block_stmts_opt = - Validate.( - if - Symbol_table.check_some_id_is_unassigned vm - && !check_that_all_functions_have_definition - then - match function_block_stmts_opt with - | Some {stmts= {smeta; _} :: _; _} -> - (* TODO: insert better location in the error *) - error @@ Semantic_error.fn_decl_without_def smeta.loc - | Some {stmts= []; _} | None -> - fatal_error ~msg:"semantic_check_functions_have_defn" () - else - match function_block_stmts_opt with - | Some {stmts= []; _} | None -> ok () - | Some {stmts= ls; _} -> - List.map ~f:check_fun_def_body_in_block ls - |> sequence - |> map ~f:(List.iter ~f:Fn.id)) - -(* The actual semantic checks for all AST nodes! *) -let semantic_check_program - { functionblock= fb - ; datablock= db - ; transformeddatablock= tdb - ; parametersblock= pb - ; transformedparametersblock= tpb - ; modelblock= mb - ; generatedquantitiesblock= gb - ; comments } = - (* NB: We always want to make sure we start with an empty symbol table, in - case we are processing multiple files in one run. *) - unsafe_clear_symbol_table vm ; - let cf = - { current_block= Functions - ; in_toplevel_decl= false - ; in_fun_def= false - ; in_returning_fun_def= false - ; in_rng_fun_def= false - ; in_lp_fun_def= false - ; in_udf_dist_def= false - ; loop_depth= 0 } - in - let ufb = - Validate.( - semantic_check_ostatements_in_block ~cf Functions fb - >>= fun xs -> - semantic_check_functions_have_defn xs |> map ~f:(fun () -> xs)) - in - let udb = semantic_check_ostatements_in_block ~cf Data db in - let utdb = semantic_check_ostatements_in_block ~cf TData tdb in - let upb = semantic_check_ostatements_in_block ~cf Param pb in - let utpb = semantic_check_ostatements_in_block ~cf TParam tpb in - (* Model top level variables only assigned and read in model *) - Symbol_table.begin_scope vm ; - let umb = semantic_check_ostatements_in_block ~cf Model mb in - Symbol_table.end_scope vm ; - let ugb = semantic_check_ostatements_in_block ~cf GQuant gb in - let mk_typed_prog ufb udb utdb upb utpb umb ugb : Ast.typed_program = - { functionblock= ufb - ; datablock= udb - ; transformeddatablock= utdb - ; parametersblock= upb - ; transformedparametersblock= utpb - ; modelblock= umb - ; generatedquantitiesblock= ugb - ; comments } - in - let apply_to x f = Validate.apply ~f x in - let check_correctness_invariant (decorated_ast : typed_program) : - typed_program = - if - compare_untyped_program - { functionblock= fb - ; datablock= db - ; transformeddatablock= tdb - ; parametersblock= pb - ; transformedparametersblock= tpb - ; modelblock= mb - ; generatedquantitiesblock= gb - ; comments } - (untyped_program_of_typed_program decorated_ast) - = 0 - then decorated_ast - else - raise_s - [%message - "Type checked AST does not match original AST. Please file a bug!" - (decorated_ast : typed_program)] - in - let check_correctness_invariant_validate = - Validate.map ~f:check_correctness_invariant - in - Validate.( - ok mk_typed_prog |> apply_to ufb |> apply_to udb |> apply_to utdb - |> apply_to upb |> apply_to utpb |> apply_to umb |> apply_to ugb - |> check_correctness_invariant_validate - |> get_with - ~with_ok:(fun ok -> Result.Ok (attach_warnings ok)) - ~with_errors:(fun errs -> Result.Error errs)) diff --git a/src/frontend/Semantic_check.mli b/src/frontend/Semantic_check.mli deleted file mode 100644 index 7d9c5b0d45..0000000000 --- a/src/frontend/Semantic_check.mli +++ /dev/null @@ -1,30 +0,0 @@ -(** Semantic validation of AST*) - -open Core_kernel - -val inferred_unsizedtype_of_indexed_exn : - loc:Middle.Location_span.t - -> Middle.UnsizedType.t - -> Ast.typed_expression Ast.index list - -> Middle.UnsizedType.t -(** Infers unsized type of an `Indexed` expression *) - -val semantic_check_binop_exn : - Middle.Location_span.t - -> Middle.Operator.t - -> Ast.typed_expression * Ast.typed_expression - -> Ast.typed_expression - -val semantic_check_program : - Ast.untyped_program - -> ( Ast.typed_program * Middle.Warnings.t list - , Middle.Semantic_error.t list ) - result -(** Performs semantic check on AST and returns original AST embellished with type decorations *) - -val check_that_all_functions_have_definition : bool ref -(** A switch to determine whether we check that all functions have a definition *) - -val model_name : string ref -(** A reference to hold the model name. Relevant for checking variable - clashes and used in code generation. *) diff --git a/src/middle/Semantic_error.ml b/src/frontend/Semantic_error.ml similarity index 96% rename from src/middle/Semantic_error.ml rename to src/frontend/Semantic_error.ml index 9a2fc92718..a09b4e926f 100644 --- a/src/middle/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -1,4 +1,5 @@ open Core_kernel +open Middle (** Type errors that may arise during semantic check *) module TypeError = struct @@ -36,16 +37,10 @@ module TypeError = struct | NonReturningFnExpectedReturningFound of string | NonReturningFnExpectedNonFnFound of string | NonReturningFnExpectedUndeclaredIdentFound of string - | IllTypedStanLibFunctionApp of + | IllTypedFunctionApp of string * UnsizedType.t list * (SignatureMismatch.signature_error list * bool) - | IllTypedUserDefinedFunctionApp of - string - * (UnsizedType.autodifftype * UnsizedType.t) list - * UnsizedType.returntype - * UnsizedType.t list - * SignatureMismatch.function_mismatch | IllTypedBinaryOperator of Operator.t * UnsizedType.t * UnsizedType.t | IllTypedPrefixOperator of Operator.t * UnsizedType.t | IllTypedPostfixOperator of Operator.t * UnsizedType.t @@ -173,12 +168,8 @@ module TypeError = struct "A non-returning function was expected but an undeclared identifier \ '%s' was supplied." fn_name - | IllTypedStanLibFunctionApp (name, arg_tys, errors) -> + | IllTypedFunctionApp (name, arg_tys, errors) -> SignatureMismatch.pp_signature_mismatch ppf (name, arg_tys, errors) - | IllTypedUserDefinedFunctionApp - (name, listed_tys, return_ty, arg_tys, error) -> - SignatureMismatch.pp_signature_mismatch ppf - (name, arg_tys, ([((return_ty, listed_tys), error)], false)) | IllTypedBinaryOperator (op, lt, rt) -> Fmt.pf ppf "Ill-typed arguments supplied to infix operator %a. Available \ @@ -387,7 +378,7 @@ module StatementError = struct "Function '%s' has already been declared. A definition is expected." name | FunDeclNoDefn -> - Fmt.pf ppf "Some function is declared without specifying a definition." + Fmt.pf ppf "Function is declared without specifying a definition." | FunDeclNeedsBlock -> Fmt.pf ppf "Function definitions must be wrapped in curly braces." | NonRealProbFunDef -> @@ -505,15 +496,8 @@ let nonreturning_fn_expected_nonfn_found loc name = let nonreturning_fn_expected_undeclaredident_found loc name = TypeError (loc, TypeError.NonReturningFnExpectedUndeclaredIdentFound name) -let illtyped_stanlib_fn_app loc name errors arg_tys = - TypeError (loc, TypeError.IllTypedStanLibFunctionApp (name, arg_tys, errors)) - -let illtyped_userdefined_fn_app loc name decl_arg_tys decl_return_ty error - arg_tys = - TypeError - ( loc - , TypeError.IllTypedUserDefinedFunctionApp - (name, decl_arg_tys, decl_return_ty, arg_tys, error) ) +let illtyped_fn_app loc name errors arg_tys = + TypeError (loc, TypeError.IllTypedFunctionApp (name, arg_tys, errors)) let illtyped_binary_op loc op lt rt = TypeError (loc, TypeError.IllTypedBinaryOperator (op, lt, rt)) diff --git a/src/middle/Semantic_error.mli b/src/frontend/Semantic_error.mli similarity index 95% rename from src/middle/Semantic_error.mli rename to src/frontend/Semantic_error.mli index b6b798818e..b264c61ff5 100644 --- a/src/middle/Semantic_error.mli +++ b/src/frontend/Semantic_error.mli @@ -1,3 +1,5 @@ +open Middle + type t val pp : Format.formatter -> t -> unit @@ -64,22 +66,13 @@ val nonreturning_fn_expected_nonfn_found : Location_span.t -> string -> t val nonreturning_fn_expected_undeclaredident_found : Location_span.t -> string -> t -val illtyped_stanlib_fn_app : +val illtyped_fn_app : Location_span.t -> string -> SignatureMismatch.signature_error list * bool -> UnsizedType.t list -> t -val illtyped_userdefined_fn_app : - Location_span.t - -> string - -> (UnsizedType.autodifftype * UnsizedType.t) list - -> UnsizedType.returntype - -> SignatureMismatch.function_mismatch - -> UnsizedType.t list - -> t - val illtyped_binary_op : Location_span.t -> Operator.t -> UnsizedType.t -> UnsizedType.t -> t diff --git a/src/middle/SignatureMismatch.ml b/src/frontend/SignatureMismatch.ml similarity index 94% rename from src/middle/SignatureMismatch.ml rename to src/frontend/SignatureMismatch.ml index 96f6878100..b05bdf21c5 100644 --- a/src/middle/SignatureMismatch.ml +++ b/src/frontend/SignatureMismatch.ml @@ -1,4 +1,5 @@ open Core_kernel +open Middle module TypeMap = Core_kernel.Map.Make_using_comparator (UnsizedType) let set ctx key data = ctx := TypeMap.set !ctx ~key ~data @@ -148,17 +149,26 @@ and check_compatible_arguments depth args1 args2 = let check_compatible_arguments_mod_conv = check_compatible_arguments 0 let max_n_errors = 5 -let stan_math_returntype name args = - (* NB: Variadic arguments are special-cased in Semantic_check and not handled here *) +let extract_function_types f = + match f with + | Environment.({type_= UFun (args, return, _, mem); kind= `StanMath}) -> + Some (return, args, (fun x -> Ast.StanLib x), mem) + | {type_= UFun (args, return, _, mem); _} -> + Some (return, args, (fun x -> UserDefined x), mem) + | _ -> None + +let returntype env name args = + (* NB: Variadic arguments are special-cased in the typechecker and not handled here *) let name = Utils.stdlib_distribution_name name in - Hashtbl.find_multi Stan_math_signatures.stan_math_signatures name - |> List.sort ~compare:(fun (x, _, _) (y, _, _) -> + Environment.find env name + |> List.filter_map ~f:extract_function_types + |> List.sort ~compare:(fun (x, _, _, _) (y, _, _, _) -> UnsizedType.compare_returntype x y ) (* Check the least return type first in case there are multiple options (due to implicit UInt-UReal conversion), where UInt List.fold_until ~init:[] - ~f:(fun errors (rt, tys, _) -> + ~f:(fun errors (rt, tys, funkind_constructor, _) -> match check_compatible_arguments 0 tys args with - | None -> Stop (Ok rt) + | None -> Stop (Ok (rt, funkind_constructor)) | Some e -> Continue (((rt, tys), e) :: errors) ) ~finish:(fun errors -> let errors = diff --git a/src/middle/SignatureMismatch.mli b/src/frontend/SignatureMismatch.mli similarity index 84% rename from src/middle/SignatureMismatch.mli rename to src/frontend/SignatureMismatch.mli index 5baeb788fd..e2a7b1ffb5 100644 --- a/src/middle/SignatureMismatch.mli +++ b/src/frontend/SignatureMismatch.mli @@ -1,4 +1,5 @@ open Core_kernel +open Middle type type_mismatch = | DataOnlyError @@ -23,10 +24,13 @@ val check_compatible_arguments_mod_conv : -> (UnsizedType.autodifftype * UnsizedType.t) list -> function_mismatch option -val stan_math_returntype : - string +val returntype : + Environment.t + -> string -> (UnsizedType.autodifftype * UnsizedType.t) list - -> (UnsizedType.returntype, signature_error list * bool) result + -> ( UnsizedType.returntype * (bool Middle.Fun_kind.suffix -> Ast.fun_kind) + , signature_error list * bool ) + result val check_variadic_args : bool @@ -47,3 +51,5 @@ val pp_signature_mismatch : list * bool ) -> unit + +val compare_errors : function_mismatch -> function_mismatch -> int diff --git a/src/frontend/Symbol_table.ml b/src/frontend/Symbol_table.ml deleted file mode 100644 index a6d975ae21..0000000000 --- a/src/frontend/Symbol_table.ml +++ /dev/null @@ -1,79 +0,0 @@ -(** Symbol table to implement var map *) - -open Core_kernel - -(* TODO: I'm sure this implementation could be made more efficient if that's necessary. There's no need for all the string comparison. -We could just keep track of the count of the entry into the hash table and use that for comparison. *) -type 'a state = - { table: (string, 'a) Hashtbl.t - ; stack: string Stack.t - ; scopedepth: int ref - ; readonly: (string, unit) Hashtbl.t - ; isunassigned: (string, unit) Hashtbl.t - ; globals: (string, unit) Hashtbl.t } - -let initialize () = - { table= String.Table.create () - ; stack= Stack.create () - ; scopedepth= ref 0 - ; readonly= String.Table.create () - ; isunassigned= String.Table.create () - ; globals= String.Table.create () } - -let enter s str ty = - let _ : [`Duplicate | `Ok] = - if !(s.scopedepth) = 0 then Hashtbl.add s.globals ~key:str ~data:() - else `Ok - in - let _ : [`Duplicate | `Ok] = Hashtbl.add s.table ~key:str ~data:ty in - Stack.push s.stack str - -let look s str = Hashtbl.find s.table str - -let begin_scope s = - s.scopedepth := !(s.scopedepth) + 1 ; - Stack.push s.stack "-sentinel-new-scope-" - -(* using a string "-sentinel-new-scope-" here that can never be used as an identifier to indicate that new scope is entered *) -let end_scope s = - s.scopedepth := !(s.scopedepth) - 1 ; - while Stack.top_exn s.stack <> "-sentinel-new-scope-" do - (* we pop the stack down to where we entered the current scope and remove all variables defined since from the var map *) - Hashtbl.remove s.table (Stack.top_exn s.stack) ; - Hashtbl.remove s.readonly (Stack.top_exn s.stack) ; - Hashtbl.remove s.isunassigned (Stack.top_exn s.stack) ; - let _ : string = Stack.pop_exn s.stack in - () - done ; - let _ : string = Stack.pop_exn s.stack in - () - -let set_read_only s str = - let _ : [`Duplicate | `Ok] = Hashtbl.add s.readonly ~key:str ~data:() in - () - -let get_read_only s str = - match Hashtbl.find s.readonly str with Some () -> true | _ -> false - -let set_is_assigned s str = Hashtbl.remove s.isunassigned str - -let set_is_unassigned s str = - let _ : [`Duplicate | `Ok] = - if Hashtbl.mem s.isunassigned str then `Ok - else Hashtbl.add s.isunassigned ~key:str ~data:() - in - () - -let check_is_unassigned s str = Hashtbl.mem s.isunassigned str -let check_some_id_is_unassigned s = not (Hashtbl.length s.isunassigned = 0) - -let is_global s str = - match Hashtbl.find s.globals str with Some _ -> true | _ -> false - -let unsafe_clear_symbol_table s = - Hashtbl.clear s.table ; - Stack.clear s.stack ; - s.scopedepth := 0 ; - Hashtbl.clear s.readonly ; - Hashtbl.clear s.isunassigned ; - Hashtbl.clear s.globals diff --git a/src/frontend/Symbol_table.mli b/src/frontend/Symbol_table.mli deleted file mode 100644 index 60dae7583a..0000000000 --- a/src/frontend/Symbol_table.mli +++ /dev/null @@ -1,43 +0,0 @@ -(** Symbol table interface to implement var map *) - -type 'a state - -val initialize : unit -> 'a state -(** Creates a new symbol table *) - -val enter : 'a state -> string -> 'a -> unit -(** Enters a specified identifier with its specified type (or other) information - into a symbol table *) - -val look : 'a state -> string -> 'a option -(** Looks for an identifier in a symbol table and returns its information if found and None otherwise *) - -val begin_scope : 'a state -> unit -(** Used to start a new local scope which symbols added from now will end up in *) - -val end_scope : 'a state -> unit -(** Used to end a local scope, purging the symbol table of all symbols added in that scope *) - -val set_read_only : 'a state -> string -> unit -(** Used to add a read only label to an identifier *) - -val get_read_only : 'a state -> string -> bool -(** Used to check for a read only label for an identifier *) - -val set_is_assigned : 'a state -> string -> unit -(** Label an identifier as having been assigned to *) - -val set_is_unassigned : 'a state -> string -> unit -(** Label an identifier as not having been assigned to *) - -val check_is_unassigned : 'a state -> string -> bool -(** Check whether an identifier is labelled as unassigned *) - -val check_some_id_is_unassigned : 'a state -> bool -(** Used to check whether some identifier is labelled as unassigned *) - -val is_global : 'a state -> string -> bool -(** Used to check whether an identifier was declared in global scope *) - -val unsafe_clear_symbol_table : 'a state -> unit -(** Used to clear the whole symbol table *) diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml new file mode 100644 index 0000000000..352b6a28b8 --- /dev/null +++ b/src/frontend/Typechecker.ml @@ -0,0 +1,1493 @@ +(** a type/semantic checker for Stan ASTs + + Functions which begin with "check_" return a typed version of their input + Functions which begin with "verify_" return unit if a check succeeds, or else + throw an Errors.SemanticError exception. + Other functions which begin with "infer"/"calculate" vary. Usually they return + a value, but a few do have error conditions. + + All Error.SemanticError excpetions are caught by check_program + which turns the ast or exception into a Result.t for external usage + + A type environment (Env.t) is used to hold variables and functions, including + stan math functions. This is a functional map, meaning it is handled immutably. +*) + +open Core_kernel +open Middle +open Ast +module Env = Environment + +(* we only allow errors raised by this function *) +let error e = raise (Errors.SemanticError e) + +(* warnings are built up in a list *) +let warnings : Warnings.t list ref = ref [] + +let add_warning (span : Location_span.t) (message : string) = + warnings := (span, message) :: !warnings + +let attach_warnings x = (x, List.rev !warnings) + +(* model name - don't love this here *) +let model_name = ref "" +let check_that_all_functions_have_definition = ref true + +(* Record structure holding flags and other markers about context to be + used for error reporting. *) +type context_flags_record = + { current_block: Env.originblock + ; in_toplevel_decl: bool + ; in_fun_def: bool + ; in_returning_fun_def: bool + ; in_rng_fun_def: bool + ; in_lp_fun_def: bool + ; in_udf_dist_def: bool + ; loop_depth: int } + +let context block = + { current_block= block + ; in_toplevel_decl= false + ; in_fun_def= false + ; in_returning_fun_def= false + ; in_rng_fun_def= false + ; in_lp_fun_def= false + ; in_udf_dist_def= false + ; loop_depth= 0 } + +let calculate_autodifftype cf origin ut = + match origin with + | Env.(Param | TParam | Model | Functions) + when not (UnsizedType.contains_int ut || cf.current_block = GQuant) -> + UnsizedType.AutoDiffable + | _ -> DataOnly + +let arg_type x = (x.emeta.ad_level, x.emeta.type_) +let get_arg_types = List.map ~f:arg_type +let type_of_expr_typed ue = ue.emeta.type_ +let has_int_type ue = ue.emeta.type_ = UInt +let has_int_array_type ue = ue.emeta.type_ = UArray UInt + +let has_int_or_real_type ue = + match ue.emeta.type_ with UInt | UReal -> true | _ -> false + +(* -- General checks ---------------------------------------------- *) +let reserved_keywords = + [ "for"; "in"; "while"; "repeat"; "until"; "if"; "then"; "else"; "true" + ; "false"; "target"; "int"; "real"; "complex"; "void"; "vector"; "simplex" + ; "unit_vector"; "ordered"; "positive_ordered"; "row_vector"; "matrix" + ; "cholesky_factor_corr"; "cholesky_factor_cov"; "corr_matrix"; "cov_matrix" + ; "functions"; "model"; "data"; "parameters"; "quantities"; "transformed" + ; "generated"; "profile"; "return"; "break"; "continue"; "increment_log_prob" + ; "get_lp"; "print"; "reject"; "typedef"; "struct"; "var"; "export"; "extern" + ; "static"; "auto" ] + +let verify_identifier id : unit = + if id.name = !model_name then + Semantic_error.ident_is_model_name id.id_loc id.name |> error + else if + String.is_suffix id.name ~suffix:"__" + || List.mem reserved_keywords id.name ~equal:String.equal + then Semantic_error.ident_is_keyword id.id_loc id.name |> error + +let distribution_name_variants name = + if name = "multiply_log" || name = "binomial_coefficient_log" then [name] + else + (* this will have some duplicates, but preserves order better *) + match Utils.split_distribution_suffix name with + | Some (stem, "lpmf") | Some (stem, "lpdf") | Some (stem, "log") -> + [name; stem ^ "_lpmf"; stem ^ "_lpdf"; stem ^ "_log"] + | Some (stem, "lcdf") | Some (stem, "cdf_log") -> + [name; stem ^ "_lcdf"; stem ^ "_cdf_log"] + | Some (stem, "lccdf") | Some (stem, "ccdf_log") -> + [name; stem ^ "_lccdf"; stem ^ "_ccdf_log"] + | _ -> [name] + +(** verify that the variable being declared is previous unused. + allowed to shadow StanLib *) +let verify_name_fresh_var loc tenv name = + if Utils.is_unnormalized_distribution name then + Semantic_error.ident_has_unnormalized_suffix loc name |> error + else if + Env.mem tenv name + && not (Stan_math_signatures.is_stan_math_function_name name) + then Semantic_error.ident_in_use loc name |> error + +(** verify that the variable being declared is previous unused. + not allowed shadowing/overloading (yet)*) +let verify_name_fresh_udf loc tenv name = + if + Stan_math_signatures.is_stan_math_function_name name + (* variadic functions are currently not in math sigs *) + || Stan_math_signatures.is_reduce_sum_fn name + || Stan_math_signatures.is_variadic_ode_fn name + then Semantic_error.ident_is_stanmath_name loc name |> error + else if Utils.is_unnormalized_distribution name then + Semantic_error.udf_is_unnormalized_fn loc name |> error + else if Env.mem tenv name then + (* adapt for overloading later *) + Semantic_error.ident_in_use loc name |> error + +(** Checks that a variable/function name: + - a function/identifier does not have the _lupdf/_lupmf suffix + - is not already in use (for now) +*) +let verify_name_fresh tenv id ~is_udf = + let f = + if is_udf then verify_name_fresh_udf id.id_loc tenv + else verify_name_fresh_var id.id_loc tenv + in + List.iter ~f (distribution_name_variants id.name) + +let is_of_compatible_return_type rt1 srt2 = + UnsizedType.( + match (rt1, srt2) with + | Void, NoReturnType + |Void, Incomplete Void + |Void, Complete Void + |Void, AnyReturnType -> + true + | ReturnType UReal, Complete (ReturnType UInt) -> true + | ReturnType UComplex, Complete (ReturnType UReal) -> true + | ReturnType UComplex, Complete (ReturnType UInt) -> true + | ReturnType rt1, Complete (ReturnType rt2) -> rt1 = rt2 + | ReturnType _, AnyReturnType -> true + | _ -> false) + +(* -- Expressions ------------------------------------------------- *) +let check_ternary_if loc pe te fe = + match + (pe.emeta.type_, UnsizedType.common_type (te.emeta.type_, fe.emeta.type_)) + with + | UInt, Some type_ when not (UnsizedType.is_fun_type type_) -> + mk_typed_expression + ~expr:(TernaryIf (pe, te, fe)) + ~ad_level:(expr_ad_lub [pe; te; fe]) + ~type_ ~loc + | _, _ -> + Semantic_error.illtyped_ternary_if loc pe.emeta.type_ te.emeta.type_ + fe.emeta.type_ + |> error + +let check_binop loc op le re = + let rt = + [le; re] |> get_arg_types + |> Stan_math_signatures.operator_stan_math_return_type op + in + match rt with + | Some (ReturnType type_) -> + mk_typed_expression + ~expr:(BinOp (le, op, re)) + ~ad_level:(expr_ad_lub [le; re]) + ~type_ ~loc + | _ -> + Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_ + |> error + +let check_prefixop loc op te = + let rt = + Stan_math_signatures.operator_stan_math_return_type op [arg_type te] + in + match rt with + | Some (ReturnType type_) -> + mk_typed_expression + ~expr:(PrefixOp (op, te)) + ~ad_level:(expr_ad_lub [te]) + ~type_ ~loc + | _ -> Semantic_error.illtyped_prefix_op loc op te.emeta.type_ |> error + +let check_postfixop loc op te = + let rt = + Stan_math_signatures.operator_stan_math_return_type op [arg_type te] + in + match rt with + | Some (ReturnType type_) -> + mk_typed_expression + ~expr:(PostfixOp (te, op)) + ~ad_level:(expr_ad_lub [te]) + ~type_ ~loc + | _ -> Semantic_error.illtyped_postfix_op loc op te.emeta.type_ |> error + +let check_variable cf loc tenv id = + match Env.find tenv (Utils.stdlib_distribution_name id.name) with + | [] -> + (* OCaml in these situations suggests similar names + We could too, if we did a fuzzy search on the keys in tenv + *) + Semantic_error.ident_not_in_scope loc id.name |> error + | {kind= `StanMath; _} :: _ -> + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype cf MathLibrary UMathLibraryFunction) + ~type_:UMathLibraryFunction ~loc + | {kind= `Variable {origin= Param | TParam | GQuant; _}; _} :: _ + when cf.in_toplevel_decl -> + Semantic_error.non_data_variable_size_decl loc |> error + | _ :: _ + when Utils.is_unnormalized_distribution id.name + && not + ( (cf.in_fun_def && (cf.in_udf_dist_def || cf.in_lp_fun_def)) + || cf.current_block = Model ) -> + Semantic_error.invalid_unnormalized_fn loc |> error + | {kind= `Variable {origin; _}; type_} :: _ -> + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype cf origin type_) + ~type_ ~loc + (* TODO - When it's time for overloading, will this need + some kind of filter/match on arg types? *) + | { kind= `UserDefined | `UserDeclared _ + ; type_= UFun (args, rt, FnLpdf _, mem_pattern) } + :: _ -> + let type_ = + UnsizedType.UFun + (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) + in + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype cf Functions type_) + ~type_ ~loc + | {kind= `UserDefined | `UserDeclared _; type_} :: _ -> + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype cf Functions type_) + ~type_ ~loc + +let get_consistent_types ad_level type_ es = + let f state e = + match state with + | Error e -> Error e + | Ok (ad, ty) -> ( + let ad = + if UnsizedType.autodifftype_can_convert e.emeta.ad_level ad then + e.emeta.ad_level + else ad + in + match UnsizedType.common_type (ty, e.emeta.type_) with + | Some ty -> Ok (ad, ty) + | None -> Error (ty, e.emeta) ) + in + List.fold ~init:(Ok (ad_level, type_)) ~f es + +let check_array_expr loc es = + match es with + | [] -> Semantic_error.empty_array loc |> error + | {emeta= {ad_level; type_; _}; _} :: elements -> ( + match get_consistent_types ad_level type_ elements with + | Error (ty, meta) -> + Semantic_error.mismatched_array_types meta.loc ty meta.type_ |> error + | Ok (ad_level, type_) -> + let type_ = UnsizedType.UArray type_ in + mk_typed_expression ~expr:(ArrayExpr es) ~ad_level ~type_ ~loc ) + +let check_rowvector loc es = + match es with + | {emeta= {ad_level; type_= UnsizedType.URowVector; _}; _} :: elements -> ( + match get_consistent_types ad_level URowVector elements with + | Ok (ad_level, _) -> + mk_typed_expression ~expr:(RowVectorExpr es) ~ad_level ~type_:UMatrix + ~loc + | Error (_, meta) -> + Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error ) + | _ -> ( + match get_consistent_types DataOnly UReal es with + | Ok (ad_level, _) -> + mk_typed_expression ~expr:(RowVectorExpr es) ~ad_level + ~type_:URowVector ~loc + | Error (_, meta) -> + Semantic_error.invalid_row_vector_types meta.loc meta.type_ |> error ) + +(* index checking *) + +let indexing_type idx = + match idx with + | Single {emeta= {type_= UnsizedType.UInt; _}; _} -> `Single + | _ -> `Multi + +let inferred_unsizedtype_of_indexed ~loc ut indices = + let rec aux type_ idcs = + match (type_, idcs) with + | _, [] -> type_ + | UnsizedType.UArray type_, `Single :: tl -> aux type_ tl + | UArray type_, `Multi :: tl -> aux type_ tl |> UnsizedType.UArray + | (UVector | URowVector), [`Single] | UMatrix, [`Single; `Single] -> + UnsizedType.UReal + | (UVector | URowVector | UMatrix), [`Multi] | UMatrix, [`Multi; `Multi] -> + type_ + | UMatrix, ([`Single] | [`Single; `Multi]) -> UnsizedType.URowVector + | UMatrix, [`Multi; `Single] -> UnsizedType.UVector + | UMatrix, _ :: _ :: _ :: _ + |(UVector | URowVector), _ :: _ :: _ + |(UInt | UReal | UComplex | UFun _ | UMathLibraryFunction), _ :: _ -> + Semantic_error.not_indexable loc ut (List.length indices) |> error + in + aux ut (List.map ~f:indexing_type indices) + +let inferred_ad_type_of_indexed at uindices = + UnsizedType.lub_ad_type + ( at + :: List.map + ~f:(function + | All -> UnsizedType.DataOnly + | Single ue1 | Upfrom ue1 | Downfrom ue1 -> + UnsizedType.lub_ad_type [at; ue1.emeta.ad_level] + | Between (ue1, ue2) -> + UnsizedType.lub_ad_type + [at; ue1.emeta.ad_level; ue2.emeta.ad_level]) + uindices ) + +(* function checking *) + +let verify_conddist_name loc id = + if + List.exists + ~f:(fun x -> String.is_suffix id.name ~suffix:x) + Utils.conditioning_suffices + then () + else Semantic_error.conditional_notation_not_allowed loc |> error + +let verify_fn_conditioning loc id = + if + List.exists + ~f:(fun suffix -> String.is_suffix id.name ~suffix) + Utils.conditioning_suffices + && not (String.is_suffix id.name ~suffix:"_cdf") + then Semantic_error.conditioning_required loc |> error + +(** `Target+=` can only be used in model and functions + with right suffix (same for tilde etc) +*) +let verify_fn_target_plus_equals cf loc id = + if + String.is_suffix id.name ~suffix:"_lp" + && not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + +(** Rng functions cannot be used in Tp or Model and only + in function defs with the right suffix +*) +let verify_fn_rng cf loc id = + if String.is_suffix id.name ~suffix:"_rng" && cf.in_toplevel_decl then + Semantic_error.invalid_decl_rng_fn loc |> error + else if + String.is_suffix id.name ~suffix:"_rng" + && ( (cf.in_fun_def && not cf.in_rng_fun_def) + || cf.current_block = TParam || cf.current_block = Model ) + then Semantic_error.invalid_rng_fn loc |> error + +(** unnormalized _lpdf/_lpmf functions can only be used in _lpdf/_lpmf/_lp udfs + or the model block +*) +let verify_unnormalized cf loc id = + if + Utils.is_unnormalized_distribution id.name + && not ((cf.in_fun_def && cf.in_udf_dist_def) || cf.current_block = Model) + then Semantic_error.invalid_unnormalized_fn loc |> error + +let mk_fun_app ~is_cond_dist (x, y, z) = + if is_cond_dist then CondDistApp (x, y, z) else FunApp (x, y, z) + +let check_fn ~is_cond_dist loc tenv id es = + match Env.find tenv (Utils.normalized_name id.name) with + | {kind= `Variable _; _} :: _ + (* variables can sometimes shadow stanlib functions, so we have to check this *) + when not + (Stan_math_signatures.is_stan_math_function_name + (Utils.normalized_name id.name)) -> + Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error + | [] -> + Semantic_error.returning_fn_expected_undeclaredident_found loc id.name + |> error + | _ (* a function *) -> ( + match SignatureMismatch.returntype tenv id.name (get_arg_types es) with + | Ok (Void, _) -> + Semantic_error.returning_fn_expected_nonreturning_found loc id.name + |> error + | Ok (ReturnType ut, fnk) -> + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (fnk (Fun_kind.suffix_from_name id.name), id, es)) + ~ad_level:(expr_ad_lub es) ~type_:ut ~loc + | Error x -> + es + |> List.map ~f:(fun e -> e.emeta.type_) + |> Semantic_error.illtyped_fn_app loc id.name x + |> error ) + +let check_reduce_sum ~is_cond_dist loc id es = + match es with + | { emeta= + { type_= + UnsizedType.UFun + (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _); _ + }; _ } + :: _ + when List.mem Stan_math_signatures.reduce_sum_slice_types + sliced_arg_fun_type ~equal:( = ) -> ( + let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in + let mandatory_fun_args = + [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] + in + match + SignatureMismatch.check_variadic_args true mandatory_args + mandatory_fun_args UReal (get_arg_types es) + with + | None -> + mk_typed_expression + ~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es)) + ~ad_level:(expr_ad_lub es) ~type_:UnsizedType.UReal ~loc + | Some (expected_args, err) -> + Semantic_error.illtyped_reduce_sum loc id.name + (List.map ~f:type_of_expr_typed es) + expected_args err + |> error ) + | _ -> + let mandatory_args = + UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] + in + let mandatory_fun_args = + UnsizedType. + [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] + in + let expected_args, err = + SignatureMismatch.check_variadic_args true mandatory_args + mandatory_fun_args UReal (get_arg_types es) + |> Option.value_exn + in + Semantic_error.illtyped_reduce_sum_generic loc id.name + (List.map ~f:type_of_expr_typed es) + expected_args err + |> error + +let check_variadic_ode ~is_cond_dist loc id es = + let optional_tol_mandatory_args = + if Stan_math_signatures.variadic_ode_adjoint_fn = id.name then + Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types + else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn id.name then + Stan_math_signatures.variadic_ode_tol_arg_types + else [] + in + let mandatory_arg_types = + Stan_math_signatures.variadic_ode_mandatory_arg_types + @ optional_tol_mandatory_args + in + match + SignatureMismatch.check_variadic_args false mandatory_arg_types + Stan_math_signatures.variadic_ode_mandatory_fun_args + Stan_math_signatures.variadic_ode_fun_return_type (get_arg_types es) + with + | None -> + mk_typed_expression + ~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es)) + ~ad_level:(expr_ad_lub es) + ~type_:Stan_math_signatures.variadic_ode_return_type ~loc + | Some (expected_args, err) -> + Semantic_error.illtyped_variadic_ode loc id.name + (List.map ~f:type_of_expr_typed es) + expected_args err + |> error + +let check_fn ~is_cond_dist loc tenv id es = + if Stan_math_signatures.is_reduce_sum_fn id.name then + check_reduce_sum ~is_cond_dist loc id es + else if Stan_math_signatures.is_variadic_ode_fn id.name then + check_variadic_ode ~is_cond_dist loc id es + else check_fn ~is_cond_dist loc tenv id es + +let rec check_funapp loc cf tenv ~is_cond_dist id tes = + (* overloading will need to defer typechecking of arguments? *) + let name_check = + if is_cond_dist then verify_conddist_name else verify_fn_conditioning + in + let res = check_fn ~is_cond_dist loc tenv id tes in + verify_identifier id ; + name_check loc id ; + verify_fn_target_plus_equals cf loc id ; + verify_fn_rng cf loc id ; + verify_unnormalized cf loc id ; + res + +and check_indexed loc cf tenv e indices = + let tindices = List.map ~f:(check_index cf tenv) indices in + let te = check_expression cf tenv e in + let ad_level = inferred_ad_type_of_indexed te.emeta.ad_level tindices in + let type_ = inferred_unsizedtype_of_indexed ~loc te.emeta.type_ tindices in + mk_typed_expression ~expr:(Indexed (te, tindices)) ~ad_level ~type_ ~loc + +and check_index cf tenv = function + | All -> All + (* Check that indexes have int (container) type *) + | Single e -> + let te = check_expression cf tenv e in + if has_int_type te || has_int_array_type te then Single te + else + Semantic_error.int_intarray_or_range_expected te.emeta.loc + te.emeta.type_ + |> error + | Upfrom e -> check_expression_of_int_type cf tenv e "Range bound" |> Upfrom + | Downfrom e -> + check_expression_of_int_type cf tenv e "Range bound" |> Downfrom + | Between (e1, e2) -> + let le = check_expression_of_int_type cf tenv e1 "Range bound" in + let ue = check_expression_of_int_type cf tenv e2 "Range bound" in + Between (le, ue) + +and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) : + Ast.typed_expression = + let loc = emeta.loc in + let ce = check_expression cf tenv in + match expr with + | TernaryIf (e1, e2, e3) -> + let pe = ce e1 in + let te = ce e2 in + let fe = ce e3 in + check_ternary_if loc pe te fe + | BinOp (e1, op, e2) -> + let le = ce e1 in + let re = ce e2 in + let warn_int_division x y = + match (x.emeta.type_, y.emeta.type_, op) with + | UInt, UInt, Divide -> + let hint ppf () = + match (x.expr, y.expr) with + | IntNumeral x, _ -> + Fmt.pf ppf "%s.0 / %a" x Pretty_printing.pp_typed_expression + y + | _, Ast.IntNumeral y -> + Fmt.pf ppf "%a / %s.0" Pretty_printing.pp_typed_expression x + y + | _ -> + Fmt.pf ppf "%a * 1.0 / %a" + Pretty_printing.pp_typed_expression x + Pretty_printing.pp_typed_expression y + in + let s = + Fmt.strf + "@[@[Found int division:@]@ @[%a@]@,@[%a@]@ @[%a@]@,@[%a@]@]" + Pretty_printing.pp_expression {expr; emeta} Fmt.text + "Values will be rounded towards zero. If rounding is not \ + desired you can write the division as" + hint () Fmt.text + "If rounding is intended please use the integer division \ + operator %/%." + in + add_warning x.emeta.loc s + | _ -> () + in + warn_int_division le re ; check_binop loc op le re + | PrefixOp (op, e) -> ce e |> check_prefixop loc op + | PostfixOp (e, op) -> ce e |> check_postfixop loc op + | Variable id -> + verify_identifier id ; + check_variable cf loc tenv id + | IntNumeral s -> ( + match float_of_string_opt s with + | Some i when i < 2_147_483_648.0 -> + mk_typed_expression ~expr:(IntNumeral s) ~ad_level:DataOnly ~type_:UInt + ~loc + | _ -> Semantic_error.bad_int_literal loc |> error ) + | RealNumeral s -> + mk_typed_expression ~expr:(RealNumeral s) ~ad_level:DataOnly ~type_:UReal + ~loc + | ImagNumeral s -> + mk_typed_expression ~expr:(ImagNumeral s) ~ad_level:DataOnly + ~type_:UComplex ~loc + | GetLP -> + (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) + if + not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then + Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + else + mk_typed_expression ~expr:GetLP + ~ad_level:(calculate_autodifftype cf cf.current_block UReal) + ~type_:UReal ~loc + | GetTarget -> + (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) + if + not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then + Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + else + mk_typed_expression ~expr:GetTarget + ~ad_level:(calculate_autodifftype cf cf.current_block UReal) + ~type_:UReal ~loc + | ArrayExpr es -> es |> List.map ~f:ce |> check_array_expr loc + | RowVectorExpr es -> es |> List.map ~f:ce |> check_rowvector loc + | Paren e -> + let te = ce e in + mk_typed_expression ~expr:(Paren te) ~ad_level:te.emeta.ad_level + ~type_:te.emeta.type_ ~loc + | Indexed (e, indices) -> check_indexed loc cf tenv e indices + | FunApp ((), id, es) -> + es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:false id + | CondDistApp ((), id, es) -> + es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:true id + +and check_expression_of_int_type cf tenv e name = + let te = check_expression cf tenv e in + if has_int_type te then te + else Semantic_error.int_expected te.emeta.loc name te.emeta.type_ |> error + +let check_expression_of_int_or_real_type cf tenv e name = + let te = check_expression cf tenv e in + if has_int_or_real_type te then te + else + Semantic_error.int_or_real_expected te.emeta.loc name te.emeta.type_ + |> error + +let check_expression_of_scalar_or_type cf tenv t e name = + let te = check_expression cf tenv e in + if UnsizedType.is_scalar_type te.emeta.type_ || te.emeta.type_ = t then te + else + Semantic_error.scalar_or_type_expected te.emeta.loc name t te.emeta.type_ + |> error + +(* -- Statements ------------------------------------------------- *) +(* non returning functions *) +let verify_nrfn_target loc cf id = + if + String.is_suffix id.name ~suffix:"_lp" + && not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + +let check_nrfn loc tenv id es = + match Env.find tenv id.name with + | {kind= `Variable _; _} :: _ + (* variables can shadow stanlib functions, so we have to check this *) + when not (Stan_math_signatures.is_stan_math_function_name id.name) -> + Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name |> error + | [] -> + Semantic_error.nonreturning_fn_expected_undeclaredident_found loc id.name + |> error + | _ (* a function *) -> ( + match SignatureMismatch.returntype tenv id.name (get_arg_types es) with + | Ok (Void, fnk) -> + mk_typed_statement + ~stmt:(NRFunApp (fnk (Fun_kind.suffix_from_name id.name), id, es)) + ~return_type:NoReturnType ~loc + | Ok (ReturnType _, _) -> + Semantic_error.nonreturning_fn_expected_returning_found loc id.name + |> error + | Error x -> + es + |> List.map ~f:type_of_expr_typed + |> Semantic_error.illtyped_fn_app loc id.name x + |> error ) + +let check_nr_fn_app loc cf tenv id es = + let tes = List.map ~f:(check_expression cf tenv) es in + verify_identifier id ; + verify_nrfn_target loc cf id ; + check_nrfn loc tenv id tes + +(* assignments *) +let verify_assignment_read_only loc is_readonly id = + if is_readonly then + Semantic_error.cannot_assign_to_read_only loc id.name |> error + +(* Variables from previous blocks are read-only. + In particular, data and parameters never assigned to + *) +let verify_assignment_global loc cf block is_global id = + if (not is_global) || block = cf.current_block then () + else Semantic_error.cannot_assign_to_global loc id.name |> error + +let mk_assignment_from_indexed_expr assop lhs rhs = + Assignment + {assign_lhs= Ast.lvalue_of_expr lhs; assign_op= assop; assign_rhs= rhs} + +let check_assignment_operator loc assop lhs rhs = + let err op = + Semantic_error.illtyped_assignment loc op lhs.emeta.type_ rhs.emeta.type_ + in + let () = + match assop with + | Assign | ArrowAssign -> + if + UnsizedType.check_of_same_type_mod_array_conv "" lhs.emeta.type_ + rhs.emeta.type_ + then () + else err Operator.Equals |> error + | OperatorAssign op -> ( + let args = List.map ~f:arg_type [lhs; rhs] in + let return_type = + Stan_math_signatures.assignmentoperator_stan_math_return_type op args + in + match return_type with Some Void -> () | _ -> err op |> error ) + in + mk_typed_statement ~return_type:NoReturnType ~loc + ~stmt:(mk_assignment_from_indexed_expr assop lhs rhs) + +let check_assignment loc cf tenv assign_lhs assign_op assign_rhs = + let assign_id = Ast.id_of_lvalue assign_lhs in + let lhs = assign_lhs |> expr_of_lvalue |> check_expression cf tenv in + let rhs = check_expression cf tenv assign_rhs in + let block, global, readonly = + let var = Env.find tenv assign_id.name in + match var with + | {kind= `Variable {origin; global; readonly}; _} :: _ -> + (origin, global, readonly) + | {kind= `StanMath; _} :: _ -> (MathLibrary, true, false) + | {kind= `UserDefined | `UserDeclared _; _} :: _ -> (Functions, true, false) + | _ -> Semantic_error.ident_not_in_scope loc assign_id.name |> error + in + verify_assignment_global loc cf block global assign_id ; + verify_assignment_read_only loc readonly assign_id ; + check_assignment_operator loc assign_op lhs rhs + +(* target plus-equals / increment log-prob *) + +let verify_target_pe_expr_type loc e = + if UnsizedType.is_fun_type e.emeta.type_ then + Semantic_error.int_or_real_container_expected loc e.emeta.type_ |> error + +let verify_target_pe_usage loc cf = + if cf.in_lp_fun_def || cf.current_block = Model then () + else Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + +let check_target_pe loc cf tenv e = + let te = check_expression cf tenv e in + verify_target_pe_usage loc cf ; + verify_target_pe_expr_type loc te ; + mk_typed_statement ~stmt:(TargetPE te) ~return_type:NoReturnType ~loc + +let check_incr_logprob loc cf tenv e = + let te = check_expression cf tenv e in + verify_target_pe_usage loc cf ; + verify_target_pe_expr_type loc te ; + mk_typed_statement ~stmt:(IncrementLogProb te) ~return_type:NoReturnType ~loc + +(* tilde/sampling notation*) +let verify_sampling_pdf_pmf id = + if + String.( + is_suffix id.name ~suffix:"_lpdf" + || is_suffix id.name ~suffix:"_lpmf" + || is_suffix id.name ~suffix:"_lupdf" + || is_suffix id.name ~suffix:"_lupmf") + then Semantic_error.invalid_sampling_pdf_or_pmf id.id_loc |> error + +let verify_sampling_cdf_ccdf loc id = + if + String.( + is_suffix id.name ~suffix:"_cdf" || is_suffix id.name ~suffix:"_ccdf") + then Semantic_error.invalid_sampling_cdf_or_ccdf loc id.name |> error + +(* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) +let verify_valid_sampling_pos loc cf = + if cf.in_lp_fun_def || cf.current_block = Model then () + else Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + +let verify_sampling_distribution loc tenv id arguments = + let name = id.name + and argumenttypes = List.map ~f:arg_type arguments + and is_real_rt = function + | UnsizedType.ReturnType UReal -> true + | _ -> false + in + let is_name_w_suffix_sampling_dist suffix = + Stan_math_signatures.stan_math_returntype (name ^ suffix) argumenttypes + |> Option.value_map ~default:false ~f:is_real_rt + in + let is_sampling_dist_in_math = + List.exists ~f:is_name_w_suffix_sampling_dist + (Utils.distribution_suffices @ Utils.unnormalized_suffices) + && name <> "binomial_coefficient" + && name <> "multiply" + in + let is_name_w_suffix_udf_sampling_dist suffix = + let f = function + | Env.({ kind= `UserDefined | `UserDeclared _ + ; type_= UFun (listedtypes, ReturnType UReal, FnLpdf _, _) }) + when UnsizedType.check_compatible_arguments_mod_conv name listedtypes + argumenttypes -> + true + | _ -> false + in + List.exists (Env.find tenv (name ^ suffix)) ~f + in + let is_udf_sampling_dist = + List.exists ~f:is_name_w_suffix_udf_sampling_dist + (Utils.distribution_suffices @ Utils.unnormalized_suffices) + in + if is_sampling_dist_in_math || is_udf_sampling_dist then () + else Semantic_error.invalid_sampling_no_such_dist loc name |> error + +let is_cumulative_density_defined tenv id arguments = + let name = id.name + and argumenttypes = List.map ~f:arg_type arguments + and is_real_rt = function + | UnsizedType.ReturnType UReal -> true + | _ -> false + in + let is_real_rt_for_suffix suffix = + Stan_math_signatures.stan_math_returntype (name ^ suffix) argumenttypes + |> Option.value_map ~default:false ~f:is_real_rt + and valid_arg_types_for_suffix suffix = + let f = function + | Env.({ kind= `UserDefined | `UserDeclared _ + ; type_= UFun (listedtypes, ReturnType UReal, FnPlain, _) }) + when UnsizedType.check_compatible_arguments_mod_conv name listedtypes + argumenttypes -> + true + | _ -> false + in + List.exists (Env.find tenv (name ^ suffix)) ~f + in + ( is_real_rt_for_suffix "_lcdf" + || valid_arg_types_for_suffix "_lcdf" + || is_real_rt_for_suffix "_cdf_log" + || valid_arg_types_for_suffix "_cdf_log" ) + && ( is_real_rt_for_suffix "_lccdf" + || valid_arg_types_for_suffix "_lccdf" + || is_real_rt_for_suffix "_ccdf_log" + || valid_arg_types_for_suffix "_ccdf_log" ) + +let verify_can_truncate_distribution loc (arg : typed_expression) = function + | NoTruncate -> () + | _ -> + if UnsizedType.is_scalar_type arg.emeta.type_ then () + else Semantic_error.multivariate_truncation loc |> error + +let verify_sampling_cdf_defined loc tenv id truncation args = + let check e = is_cumulative_density_defined tenv id (e :: args) in + match truncation with + | NoTruncate -> () + | (TruncateUpFrom e | TruncateDownFrom e) when check e -> () + | TruncateBetween (e1, e2) when check e1 && check e2 -> () + | _ -> Semantic_error.invalid_truncation_cdf_or_ccdf loc |> error + +let check_truncation cf tenv truncation = + let check e = + check_expression_of_int_or_real_type cf tenv e "Truncation bound" + in + match truncation with + | NoTruncate -> NoTruncate + | TruncateUpFrom e -> check e |> TruncateUpFrom + | TruncateDownFrom e -> check e |> TruncateDownFrom + | TruncateBetween (e1, e2) -> (check e1, check e2) |> TruncateBetween + +let check_tilde loc cf tenv distribution truncation arg args = + let te = check_expression cf tenv arg in + let tes = List.map ~f:(check_expression cf tenv) args in + let ttrunc = check_truncation cf tenv truncation in + verify_identifier distribution ; + verify_sampling_pdf_pmf distribution ; + verify_valid_sampling_pos loc cf ; + verify_sampling_cdf_ccdf loc distribution ; + verify_sampling_distribution loc tenv distribution (te :: tes) ; + verify_sampling_cdf_defined loc tenv distribution ttrunc tes ; + verify_can_truncate_distribution loc te ttrunc ; + let stmt = Tilde {arg= te; distribution; args= tes; truncation= ttrunc} in + mk_typed_statement ~stmt ~loc ~return_type:NoReturnType + +(* Break and continue only occur in loops. *) +let check_break loc cf = + if cf.loop_depth = 0 then Semantic_error.break_outside_loop loc |> error + else mk_typed_statement ~stmt:Break ~return_type:NoReturnType ~loc + +let check_continue loc cf = + if cf.loop_depth = 0 then Semantic_error.continue_outside_loop loc |> error + else mk_typed_statement ~stmt:Continue ~return_type:NoReturnType ~loc + +let check_return loc cf tenv e = + if not cf.in_returning_fun_def then + Semantic_error.expression_return_outside_returning_fn loc |> error + else + let te = check_expression cf tenv e in + mk_typed_statement ~stmt:(Return te) + ~return_type:(Complete (ReturnType te.emeta.type_)) ~loc + +let check_returnvoid loc cf = + if (not cf.in_fun_def) || cf.in_returning_fun_def then + Semantic_error.void_ouside_nonreturning_fn loc |> error + else mk_typed_statement ~stmt:ReturnVoid ~return_type:(Complete Void) ~loc + +let check_printable cf tenv = function + | PString s -> PString s + (* Print/reject expressions cannot be of function type. *) + | PExpr e -> ( + let te = check_expression cf tenv e in + match te.emeta.type_ with + | UFun _ | UMathLibraryFunction -> + Semantic_error.not_printable te.emeta.loc |> error + | _ -> PExpr te ) + +let check_print loc cf tenv ps = + let tps = List.map ~f:(check_printable cf tenv) ps in + mk_typed_statement ~stmt:(Print tps) ~return_type:NoReturnType ~loc + +let check_reject loc cf tenv ps = + let tps = List.map ~f:(check_printable cf tenv) ps in + mk_typed_statement ~stmt:(Reject tps) ~return_type:AnyReturnType ~loc + +let check_skip loc = + mk_typed_statement ~stmt:Skip ~return_type:NoReturnType ~loc + +let rec stmt_is_escape {stmt; _} = + match stmt with + | Break | Continue | Reject _ | Return _ | ReturnVoid -> true + | _ -> false + +and list_until_escape xs = + let rec aux accu = function + | next :: next' :: _ when stmt_is_escape next' -> + List.rev (next' :: next :: accu) + | next :: rest -> aux (next :: accu) rest + | [] -> List.rev accu + in + aux [] xs + +let returntype_leastupperbound loc rt1 rt2 = + match (rt1, rt2) with + | UnsizedType.ReturnType UReal, UnsizedType.ReturnType UInt + |ReturnType UInt, ReturnType UReal -> + UnsizedType.ReturnType UReal + | _, _ when rt1 = rt2 -> rt2 + | _ -> Semantic_error.mismatched_return_types loc rt1 rt2 |> error + +let try_compute_block_statement_returntype loc srt1 srt2 = + match (srt1, srt2) with + | Complete rt1, Complete rt2 | Incomplete rt1, Complete rt2 -> + Complete (returntype_leastupperbound loc rt1 rt2) + | Incomplete rt1, Incomplete rt2 | Complete rt1, Incomplete rt2 -> + Incomplete (returntype_leastupperbound loc rt1 rt2) + | NoReturnType, NoReturnType -> NoReturnType + | AnyReturnType, Incomplete rt + |Complete rt, NoReturnType + |NoReturnType, Incomplete rt + |Incomplete rt, NoReturnType -> + Incomplete rt + | NoReturnType, Complete rt + |Complete rt, AnyReturnType + |Incomplete rt, AnyReturnType + |AnyReturnType, Complete rt -> + Complete rt + | AnyReturnType, NoReturnType + |NoReturnType, AnyReturnType + |AnyReturnType, AnyReturnType -> + AnyReturnType + +let try_compute_ifthenelse_statement_returntype loc srt1 srt2 = + match (srt1, srt2) with + | Complete rt1, Complete rt2 -> + returntype_leastupperbound loc rt1 rt2 |> Complete + | Incomplete rt1, Incomplete rt2 + |Complete rt1, Incomplete rt2 + |Incomplete rt1, Complete rt2 -> + returntype_leastupperbound loc rt1 rt2 |> Incomplete + | AnyReturnType, NoReturnType + |NoReturnType, AnyReturnType + |NoReturnType, NoReturnType -> + NoReturnType + | AnyReturnType, Incomplete rt + |Incomplete rt, AnyReturnType + |Complete rt, NoReturnType + |NoReturnType, Complete rt + |NoReturnType, Incomplete rt + |Incomplete rt, NoReturnType -> + Incomplete rt + | Complete rt, AnyReturnType | AnyReturnType, Complete rt -> Complete rt + | AnyReturnType, AnyReturnType -> AnyReturnType + +(* statements which contain statements, and therefore need to be mutually recursive + with check_statement +*) +let rec check_if_then_else loc cf tenv pred_e s_true s_false_opt = + (* we don't need these nested type environments *) + let _, ts_true = check_statement cf tenv s_true in + let ts_false_opt = + s_false_opt |> Option.map ~f:(check_statement cf tenv) |> Option.map ~f:snd + in + let te = + check_expression_of_int_or_real_type cf tenv pred_e + "Condition in conditional" + in + let stmt = IfThenElse (te, ts_true, ts_false_opt) in + let srt1 = ts_true.smeta.return_type in + let srt2 = + ts_false_opt + |> Option.map ~f:(fun s -> s.smeta.return_type) + |> Option.value ~default:NoReturnType + in + let return_type = + try_compute_ifthenelse_statement_returntype loc srt1 srt2 + in + mk_typed_statement ~stmt ~return_type ~loc + +and check_while loc cf tenv cond_e loop_body = + let _, ts = + check_statement {cf with loop_depth= cf.loop_depth + 1} tenv loop_body + and te = + check_expression_of_int_or_real_type cf tenv cond_e + "Condition in while-loop" + in + mk_typed_statement + ~stmt:(While (te, ts)) + ~return_type:ts.smeta.return_type ~loc + +and check_for loc cf tenv loop_var lower_bound_e upper_bound_e loop_body = + let te1 = + check_expression_of_int_type cf tenv lower_bound_e + "Lower bound of for-loop" + and te2 = + check_expression_of_int_type cf tenv upper_bound_e + "Upper bound of for-loop" + in + verify_identifier loop_var ; + let ts = check_loop_body cf tenv loop_var UnsizedType.UInt loop_body in + mk_typed_statement + ~stmt: + (For + { loop_variable= loop_var + ; lower_bound= te1 + ; upper_bound= te2 + ; loop_body= ts }) + ~return_type:ts.smeta.return_type ~loc + +and check_foreach_loop_identifier_type loc ty = + match ty with + | UnsizedType.UArray ut -> ut + | UVector | URowVector | UMatrix -> UnsizedType.UReal + | _ -> Semantic_error.array_vector_rowvector_matrix_expected loc ty |> error + +and check_foreach loc cf tenv loop_var foreach_e loop_body = + let te = check_expression cf tenv foreach_e in + verify_identifier loop_var ; + let loop_var_ty = + check_foreach_loop_identifier_type te.emeta.loc te.emeta.type_ + in + let ts = check_loop_body cf tenv loop_var loop_var_ty loop_body in + mk_typed_statement + ~stmt:(ForEach (loop_var, te, ts)) + ~return_type:ts.smeta.return_type ~loc + +and check_loop_body cf tenv loop_var loop_var_ty loop_body = + verify_name_fresh tenv loop_var ~is_udf:false ; + (* Add to type environment as readonly. + Check that function args and loop identifiers are not modified in + function. (passed by const ref) + *) + let tenv = + Env.add tenv loop_var.name loop_var_ty + (`Variable {origin= cf.current_block; global= false; readonly= true}) + in + snd (check_statement {cf with loop_depth= cf.loop_depth + 1} tenv loop_body) + +and check_block loc cf tenv stmts = + let _, checked_stmts = + List.fold_map stmts ~init:tenv ~f:(check_statement cf) + in + let stmts = list_until_escape checked_stmts in + let return_type = + stmts + |> List.map ~f:(fun s -> s.smeta.return_type) + |> List.fold ~init:NoReturnType + ~f:(try_compute_block_statement_returntype loc) + in + mk_typed_statement ~stmt:(Block stmts) ~return_type ~loc + +and check_profile loc cf tenv name stmts = + let _, checked_stmts = + List.fold_map stmts ~init:tenv ~f:(check_statement cf) + in + let stmts = list_until_escape checked_stmts in + let return_type = + stmts + |> List.map ~f:(fun s -> s.smeta.return_type) + |> List.fold ~init:NoReturnType + ~f:(try_compute_block_statement_returntype loc) + in + mk_typed_statement ~stmt:(Profile (name, stmts)) ~return_type ~loc + +(* variable declarations *) +and verify_valid_transformation_for_type loc is_global sized_ty trans = + let is_real {emeta; _} = emeta.type_ = UReal in + let is_real_transformation = + match trans with + | Transformation.Lower e -> is_real e + | Upper e -> is_real e + | LowerUpper (e1, e2) -> is_real e1 || is_real e2 + | _ -> false + in + if is_global && sized_ty = SizedType.SInt && is_real_transformation then + Semantic_error.non_int_bounds loc |> error ; + let is_transformation = + match trans with Transformation.Identity -> false | _ -> true + in + if + is_global + && SizedType.(inner_type sized_ty = SComplex) + && is_transformation + then Semantic_error.complex_transform loc |> error + +and verify_transformed_param_ty loc cf is_global unsized_ty = + if + is_global + && (cf.current_block = Param || cf.current_block = TParam) + && UnsizedType.contains_int unsized_ty + then Semantic_error.transformed_params_int loc |> error + +and check_sizedtype cf tenv sizedty = + let check e msg = check_expression_of_int_type cf tenv e msg in + match sizedty with + | SizedType.SInt -> SizedType.SInt + | SReal -> SReal + | SComplex -> SComplex + | SVector (mem_pattern, e) -> + let te = check e "Vector sizes" in + SizedType.SVector (mem_pattern, te) + | SRowVector (mem_pattern, e) -> + let te = check e "Row vector sizes" in + SizedType.SRowVector (mem_pattern, te) + | SMatrix (mem_pattern, e1, e2) -> + let te1 = check e1 "Matrix sizes" in + let te2 = check e2 "Matrix sizes" in + SizedType.SMatrix (mem_pattern, te1, te2) + | SArray (st, e) -> + let tst = check_sizedtype cf tenv st in + let te = check e "Array sizes" in + SizedType.SArray (tst, te) + +and check_var_decl_initial_value loc cf tenv id init_val_opt = + match init_val_opt with + | Some e -> ( + let stmt = + Assignment + { assign_lhs= {lval= LVariable id; lmeta= {loc}} + ; assign_op= Assign + ; assign_rhs= e } + in + mk_untyped_statement ~loc ~stmt + |> check_statement cf tenv |> snd + |> fun ts -> + match (ts.stmt, ts.smeta.return_type) with + | Assignment {assign_rhs= ue; _}, NoReturnType -> Some ue + | _ -> + raise_s + [%message "Internal error: check_var_decl: `Assignment` expected."] + ) + | None -> None + +and check_transformation cf tenv ut trans = + let check e msg = check_expression_of_scalar_or_type cf tenv ut e msg in + match trans with + | Transformation.Identity -> Transformation.Identity + | Lower e -> check e "Lower bound" |> Lower + | Upper e -> check e "Upper bound" |> Upper + | LowerUpper (e1, e2) -> + (check e1 "Lower bound", check e2 "Upper bound") |> LowerUpper + | Offset e -> check e "Offset" |> Offset + | Multiplier e -> check e "Multiplier" |> Multiplier + | OffsetMultiplier (e1, e2) -> + (check e1 "Offset", check e2 "Multiplier") |> OffsetMultiplier + | Ordered -> Ordered + | PositiveOrdered -> PositiveOrdered + | Simplex -> Simplex + | UnitVector -> UnitVector + | CholeskyCorr -> CholeskyCorr + | CholeskyCov -> CholeskyCov + | Correlation -> Correlation + | Covariance -> Covariance + +and check_var_decl loc cf tenv sized_ty trans id init is_global = + let checked_type = + check_sizedtype {cf with in_toplevel_decl= is_global} tenv sized_ty + in + let unsized_type = SizedType.to_unsized checked_type in + let checked_trans = check_transformation cf tenv unsized_type trans in + verify_identifier id ; + verify_name_fresh tenv id ~is_udf:false ; + let tenv = + Env.add tenv id.name unsized_type + (`Variable + {origin= cf.current_block; global= is_global; readonly= false}) + in + let tinit = check_var_decl_initial_value loc cf tenv id init in + verify_valid_transformation_for_type loc is_global checked_type checked_trans ; + verify_transformed_param_ty loc cf is_global unsized_type ; + let stmt = + VarDecl + { decl_type= Sized checked_type + ; transformation= checked_trans + ; identifier= id + ; initial_value= tinit + ; is_global } + in + (tenv, mk_typed_statement ~stmt ~loc ~return_type:NoReturnType) + +(* function definitions *) +and exists_matching_fn_declared tenv id arg_tys rt = + let f = function + | Env.({kind= `UserDeclared _; type_= UFun (listedtypes, rt', _, _)}) + when arg_tys = listedtypes && rt = rt' -> + true + | _ -> false + in + List.exists (Env.find tenv id.name) ~f + +and verify_fundef_overloaded loc tenv id arg_tys rt = + (* User defined functions cannot be overloaded at the moment + *) + if exists_matching_fn_declared tenv id arg_tys rt then () + else + let f = function Env.({kind= `UserDeclared _; _}) -> true | _ -> false in + if List.exists ~f (Env.find tenv id.name) then + (* a function of the same name but different signature *) + Env.find tenv id.name |> List.hd + |> Option.map ~f:(fun x -> x.type_) + |> Semantic_error.mismatched_fn_def_decl loc id.name + |> error + else verify_name_fresh tenv id ~is_udf:true + +and get_fn_decl_or_defn loc tenv id arg_tys rt body = + match body with + | {stmt= Skip; _} -> + if exists_matching_fn_declared tenv id arg_tys rt then + Semantic_error.fn_decl_without_def loc |> error + else `UserDeclared id.id_loc + | _ -> `UserDefined + +and verify_fundef_dist_rt loc id return_ty = + let is_dist = + List.exists + ~f:(fun x -> String.is_suffix id.name ~suffix:x) + Utils.conditioning_suffices_w_log + in + if is_dist then + match return_ty with + | UnsizedType.ReturnType UReal -> () + | _ -> Semantic_error.non_real_prob_fn_def loc |> error + +and verify_pdf_fundef_first_arg_ty loc id arg_tys = + if String.is_suffix id.name ~suffix:"_lpdf" then + let rt = List.hd arg_tys |> Option.map ~f:snd in + match rt with + | Some rt when UnsizedType.is_real_type rt -> () + | _ -> Semantic_error.prob_density_non_real_variate loc rt |> error + +and verify_pmf_fundef_first_arg_ty loc id arg_tys = + if String.is_suffix id.name ~suffix:"_lpmf" then + let rt = List.hd arg_tys |> Option.map ~f:snd in + match rt with + | Some rt when UnsizedType.is_int_type rt -> () + | _ -> Semantic_error.prob_mass_non_int_variate loc rt |> error + +and verify_fundef_distinct_arg_ids loc arg_names = + let dup_exists l = + List.find_a_dup ~compare:String.compare l |> Option.is_some + in + if dup_exists arg_names then Semantic_error.duplicate_arg_names loc |> error + +and verify_fundef_return_tys loc return_type body = + if + body.stmt = Skip + || is_of_compatible_return_type return_type body.smeta.return_type + then () + else Semantic_error.incompatible_return_types loc |> error + +and add_function tenv name type_ defined = + (* if we're providing a definition, we remove prior declarations + to simplify the environment *) + if defined = `UserDefined then + let existing_defns = Env.find tenv name in + let defns = + List.filter + ~f:(function + | Env.({kind= `UserDeclared _; type_= type'}) when type' = type_ -> + false + | _ -> true) + existing_defns + in + let new_fn = Env.{kind= `UserDefined; type_} in + Env.set_raw tenv name (new_fn :: defns) + else Env.add tenv name type_ defined + +and check_fundef loc cf tenv return_ty id args body = + List.iter args ~f:(fun (_, _, id) -> verify_identifier id) ; + verify_identifier id ; + let arg_types = List.map ~f:(fun (w, y, _) -> (w, y)) args in + let arg_identifiers = List.map ~f:(fun (_, _, z) -> z) args in + let arg_names = List.map ~f:(fun x -> x.name) arg_identifiers in + verify_fundef_overloaded loc tenv id arg_types return_ty ; + let defined = get_fn_decl_or_defn loc tenv id arg_types return_ty body in + verify_fundef_dist_rt loc id return_ty ; + verify_pdf_fundef_first_arg_ty loc id arg_types ; + verify_pmf_fundef_first_arg_ty loc id arg_types ; + let tenv = + add_function tenv id.name + (UFun (arg_types, return_ty, Fun_kind.suffix_from_name id.name, AoS)) + defined + in + List.iter + ~f:(fun id -> verify_name_fresh tenv id ~is_udf:false) + arg_identifiers ; + verify_fundef_distinct_arg_ids loc arg_names ; + (* We treat DataOnly arguments as if they are data and AutoDiffable arguments + as if they are parameters, for the purposes of type checking. + *) + let arg_types_internal = + List.map + ~f:(function + | UnsizedType.DataOnly, ut -> (Env.Data, ut) + | AutoDiffable, ut -> (Param, ut)) + arg_types + in + let tenv_body = + List.fold2_exn arg_names arg_types_internal ~init:tenv + ~f:(fun env name (origin, typ) -> + Env.add env name typ + (* readonly so that function args and loop identifiers + are not modified in function. (passed by const ref) *) + (`Variable {origin; readonly= true; global= false}) ) + in + let context = + let is_udf_dist name = + List.exists + ~f:(fun suffix -> String.is_suffix name ~suffix) + Utils.distribution_suffices + in + { cf with + in_fun_def= true + ; in_rng_fun_def= String.is_suffix id.name ~suffix:"_rng" + ; in_lp_fun_def= String.is_suffix id.name ~suffix:"_lp" + ; in_udf_dist_def= is_udf_dist id.name + ; in_returning_fun_def= return_ty <> Void } + in + let _, checked_body = check_statement context tenv_body body in + verify_fundef_return_tys loc return_ty checked_body ; + let stmt = + FunDef + {returntype= return_ty; funname= id; arguments= args; body= checked_body} + in + (* NB: **not** tenv_body, so args don't leak out *) + (tenv, mk_typed_statement ~return_type:NoReturnType ~loc ~stmt) + +and check_statement (cf : context_flags_record) (tenv : Env.t) + (s : Ast.untyped_statement) : Env.t * typed_statement = + let loc = s.smeta.loc in + match s.stmt with + | NRFunApp (_, id, es) -> (tenv, check_nr_fn_app loc cf tenv id es) + | Assignment {assign_lhs; assign_op; assign_rhs} -> + (tenv, check_assignment loc cf tenv assign_lhs assign_op assign_rhs) + | TargetPE e -> (tenv, check_target_pe loc cf tenv e) + | IncrementLogProb e -> (tenv, check_incr_logprob loc cf tenv e) + | Tilde {arg; distribution; args; truncation} -> + (tenv, check_tilde loc cf tenv distribution truncation arg args) + | Break -> (tenv, check_break loc cf) + | Continue -> (tenv, check_continue loc cf) + | Return e -> (tenv, check_return loc cf tenv e) + | ReturnVoid -> (tenv, check_returnvoid loc cf) + | Print ps -> (tenv, check_print loc cf tenv ps) + | Reject ps -> (tenv, check_reject loc cf tenv ps) + | Skip -> (tenv, check_skip loc) + (* the following can contain further statements *) + | IfThenElse (e, s1, os2) -> (tenv, check_if_then_else loc cf tenv e s1 os2) + | While (e, s) -> (tenv, check_while loc cf tenv e s) + | For {loop_variable; lower_bound; upper_bound; loop_body} -> + ( tenv + , check_for loc cf tenv loop_variable lower_bound upper_bound loop_body + ) + | ForEach (id, e, s) -> (tenv, check_foreach loc cf tenv id e s) + | Block stmts -> (tenv, check_block loc cf tenv stmts) + | Profile (name, vdsl) -> (tenv, check_profile loc cf tenv name vdsl) + | VarDecl {decl_type= Unsized _; _} -> + (* currently unallowed by parser *) + raise_s [%message "Don't support unsized declarations yet."] + (* these two are special in that they're allowed to change the type environment *) + | VarDecl + { decl_type= Sized st + ; transformation + ; identifier + ; initial_value + ; is_global } -> + check_var_decl loc cf tenv st transformation identifier initial_value + is_global + | FunDef {returntype; funname; arguments; body} -> + check_fundef loc cf tenv returntype funname arguments body + +let verify_fun_def_body_in_block = function + | {stmt= FunDef {body= {stmt= Block _; _}; _}; _} + |{stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> + () + | {stmt= FunDef {body= {stmt= _; smeta}; _}; _} -> + Semantic_error.fn_decl_needs_block smeta.loc |> error + | _ -> () + +let verify_functions_have_defn tenv function_block_stmts_opt = + let error_on_undefined funs = + List.iter funs ~f:(fun f -> + match f with + | Env.({kind= `UserDeclared loc; _}) -> + Semantic_error.fn_decl_without_def loc |> error + | _ -> () ) + in + if !check_that_all_functions_have_definition then + Env.iter tenv error_on_undefined ; + match function_block_stmts_opt with + | Some {stmts= []; _} | None -> () + | Some {stmts= ls; _} -> List.iter ~f:verify_fun_def_body_in_block ls + +let check_toplevel_block block tenv stmts_opt = + let cf = context block in + match stmts_opt with + | Some {stmts; xloc} -> + let tenv', stmts = + List.fold_map stmts ~init:tenv ~f:(check_statement cf) + in + (tenv', Some {stmts; xloc}) + | None -> (tenv, None) + +let verify_correctness_invariant (ast : untyped_program) + (decorated_ast : typed_program) = + let detyped = untyped_program_of_typed_program decorated_ast in + if compare_untyped_program ast detyped = 0 then () + else + raise_s + [%message + "Type checked AST does not match original AST. Please file a bug!" + (detyped : untyped_program) + (ast : untyped_program)] + +let check_program_exn + ( { functionblock= fb + ; datablock= db + ; transformeddatablock= tdb + ; parametersblock= pb + ; transformedparametersblock= tpb + ; modelblock= mb + ; generatedquantitiesblock= gqb + ; comments } as ast ) = + (* create a new type environment which has only stan-math functions *) + let tenv = Env.create () in + let tenv, typed_fb = check_toplevel_block Functions tenv fb in + verify_functions_have_defn tenv typed_fb ; + let tenv, typed_db = check_toplevel_block Data tenv db in + let tenv, typed_tdb = check_toplevel_block TData tenv tdb in + let tenv, typed_pb = check_toplevel_block Param tenv pb in + let tenv, typed_tpb = check_toplevel_block TParam tenv tpb in + let _, typed_mb = check_toplevel_block Model tenv mb in + let _, typed_gqb = check_toplevel_block GQuant tenv gqb in + let prog = + { functionblock= typed_fb + ; datablock= typed_db + ; transformeddatablock= typed_tdb + ; parametersblock= typed_pb + ; transformedparametersblock= typed_tpb + ; modelblock= typed_mb + ; generatedquantitiesblock= typed_gqb + ; comments } + in + verify_correctness_invariant ast prog ; + attach_warnings prog + +let check_program ast = + try Result.Ok (check_program_exn ast) with Errors.SemanticError err -> + Result.Error err diff --git a/src/frontend/Typechecker.mli b/src/frontend/Typechecker.mli new file mode 100644 index 0000000000..3cc21e9f3d --- /dev/null +++ b/src/frontend/Typechecker.mli @@ -0,0 +1,37 @@ +(** a type/semantic checker for Stan ASTs + + Functions which begin with "check_" return a typed version of their input + Functions which begin with "verify_" return unit if a check succeeds, or else + throw an [Errors.SemanticError] exception. + Other functions which begin with "infer"/"calculate" vary. Usually they return + a value, but a few do have error conditions. + + All [Error.SemanticError] excpetions are caught by check_program + which turns the ast or exception into a [Result.t] for external usage + + A type environment {!val:Environment.t} is used to hold variables and functions, including + Stan math functions. This is a functional map, meaning it is handled immutably. +*) + +open Ast + +val check_program_exn : untyped_program -> typed_program * Warnings.t list +(** + Type check a full Stan program. + Can raise [Errors.SemanticError] +*) + +val check_program : + untyped_program -> (typed_program * Warnings.t list, Semantic_error.t) result +(** + The safe version of [check_program_exn]. This catches + all [Errors.SemanticError] exceptions and converts them + into a [Result.t] +*) + +val model_name : string ref +(** A reference to hold the model name. Relevant for checking variable + clashes and used in code generation. *) + +val check_that_all_functions_have_definition : bool ref +(** A switch to determine whether we check that all functions have a definition *) diff --git a/src/middle/Warnings.ml b/src/frontend/Warnings.ml similarity index 84% rename from src/middle/Warnings.ml rename to src/frontend/Warnings.ml index b61617238f..615646ca92 100644 --- a/src/middle/Warnings.ml +++ b/src/frontend/Warnings.ml @@ -1,3 +1,6 @@ +module Location_span = Middle.Location_span +module Location = Middle.Location + type t = Location_span.t * string let pp ?printed_filename ppf (span, message) = diff --git a/src/middle/Warnings.mli b/src/frontend/Warnings.mli similarity index 71% rename from src/middle/Warnings.mli rename to src/frontend/Warnings.mli index a41dd84615..30210d4fbd 100644 --- a/src/middle/Warnings.mli +++ b/src/frontend/Warnings.mli @@ -1,4 +1,4 @@ -type t = Location_span.t * string +type t = Middle.Location_span.t * string val pp : ?printed_filename:string -> t Fmt.t val pp_warnings : ?printed_filename:string -> t list Fmt.t diff --git a/src/frontend/lexer.mll b/src/frontend/lexer.mll index 95b673e6d8..6b46195ffd 100644 --- a/src/frontend/lexer.mll +++ b/src/frontend/lexer.mll @@ -2,7 +2,6 @@ { module Stack = Core_kernel.Stack - module Errors = Middle.Errors open Lexing open Debugging open Preprocessor diff --git a/src/middle/Middle.ml b/src/middle/Middle.ml index 164c005cdc..c184819456 100644 --- a/src/middle/Middle.ml +++ b/src/middle/Middle.ml @@ -2,7 +2,6 @@ (MIR) of Stan *) -module Errors = Errors module Location = Location module Location_span = Location_span module Operator = Operator @@ -14,11 +13,8 @@ module Expr = Expr module UnsizedType = UnsizedType module SizedType = SizedType module Type = Type -module Semantic_error = Semantic_error -module SignatureMismatch = SignatureMismatch module Stmt = Stmt module Program = Program module Stan_math_signatures = Stan_math_signatures module Utils = Utils -module Warnings = Warnings module Transformation = Transformation diff --git a/src/middle/UnsizedType.ml b/src/middle/UnsizedType.ml index 4a110ecee5..a04af343a3 100644 --- a/src/middle/UnsizedType.ml +++ b/src/middle/UnsizedType.ml @@ -149,7 +149,7 @@ let is_int_type = function UInt | UArray UInt -> true | _ -> false let is_eigen_type ut = match ut with UVector | URowVector | UMatrix -> true | _ -> false -let is_fun_type = function UFun _ -> true | _ -> false +let is_fun_type = function UFun _ | UMathLibraryFunction -> true | _ -> false (** Detect if type contains an integer *) let rec contains_int ut = diff --git a/src/middle/Utils.ml b/src/middle/Utils.ml index f862682a4d..8998f41372 100644 --- a/src/middle/Utils.ml +++ b/src/middle/Utils.ml @@ -19,6 +19,14 @@ let unnormalized_suffix = function | "_lpmf" -> "_lupmf" | x -> x +(** A wrapper around String.rsplit2 which handles _cdf_log and _ccdf_log *) +let split_distribution_suffix (name : string) : (string * string) option = + let open String in + if is_suffix ~suffix:"_cdf_log" name then Some (drop_suffix name 8, "cdf_log") + else if is_suffix ~suffix:"_ccdf_log" name then + Some (drop_suffix name 9, "ccdf_log") + else rsplit2 ~on:'_' name + let is_distribution_name s = (not ( String.is_suffix s ~suffix:"_cdf_log" diff --git a/src/stan2tfp/Stan2tfp.ml b/src/stan2tfp/Stan2tfp.ml index d8afc0f8ea..8ae401b950 100644 --- a/src/stan2tfp/Stan2tfp.ml +++ b/src/stan2tfp/Stan2tfp.ml @@ -22,8 +22,7 @@ let set_model_file s = match !model_file with | "" -> model_file := s ; - Semantic_check.model_name := - remove_dotstan (Filename.basename s) ^ "_model" + Typechecker.model_name := remove_dotstan (Filename.basename s) ^ "_model" | _ -> raise_s [%message "Can only pass in one model file."] let main () = @@ -31,7 +30,7 @@ let main () = let mir = !model_file |> Frontend_utils.get_ast_or_exit |> Frontend_utils.type_ast_or_exit - |> Ast_to_Mir.trans_prog !Semantic_check.model_name + |> Ast_to_Mir.trans_prog !Typechecker.model_name in if !dump_mir then mir |> Middle.Program.Typed.sexp_of_t |> Sexp.to_string_hum diff --git a/src/stan_math_backend/Mangle.ml b/src/stan_math_backend/Mangle.ml index 089e81e24d..b4ff093fbb 100644 --- a/src/stan_math_backend/Mangle.ml +++ b/src/stan_math_backend/Mangle.ml @@ -1,24 +1,8 @@ -(** Mangle variables which are C++ reserved words into - valid C++ identifiers. - - This is done in Transform_Mir. When one of these - names is emitted as a string, we use remove_prefix - such that this mangling is opaque to the user - - e.g., a cmdstan output file would still have a column - called "public", even if internally we called this - "_stan_public" - - NB: the use of a leading _ is essential, because - the lexer won't allow this in a user-created variable. -*) open Core_kernel let kwrds_prefix = "_stan_" let prefix_len = String.length kwrds_prefix -(* Used in code generation so that this - mangling is opaque to the interfaces (e.g. cmdstan) -*) let remove_prefix s = if String.length s >= prefix_len + 1 diff --git a/src/stan_math_backend/Mangle.mli b/src/stan_math_backend/Mangle.mli index ab0752e8ac..1df7a7e073 100644 --- a/src/stan_math_backend/Mangle.mli +++ b/src/stan_math_backend/Mangle.mli @@ -1,2 +1,20 @@ +(** Mangle variables which are C++ reserved words into + valid C++ identifiers. + + This is done in Transform_Mir. When one of these + names is emitted as a string, we use remove_prefix + such that this mangling is opaque to the user - + e.g., a cmdstan output file would still have a column + called "public", even if internally we called this + "_stan_public" + + NB: the use of a leading _ is essential, because + the lexer won't allow this in a user-created variable. +*) + val remove_prefix : string -> string +(** Used in code generation so that this + mangling is opaque to the interfaces (e.g. cmdstan) +*) + val add_prefix_to_kwrds : string -> string diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index daf129b0b1..743db3a4b8 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -100,7 +100,7 @@ let options = exit 0 ) , " Display stanc version number" ) ; ( "--name" - , Arg.Set_string Semantic_check.model_name + , Arg.Set_string Typechecker.model_name , " Take a string to set the model name (default = \ \"$model_filename_model\")" ) ; ( "--O" @@ -114,10 +114,10 @@ let options = , Arg.Set print_model_cpp , " If set, output the generated C++ Stan model class to stdout." ) ; ( "--allow-undefined" - , Arg.Clear Semantic_check.check_that_all_functions_have_definition + , Arg.Clear Typechecker.check_that_all_functions_have_definition , " Do not fail if a function is declared but not defined" ) ; ( "--allow_undefined" - , Arg.Clear Semantic_check.check_that_all_functions_have_definition + , Arg.Clear Typechecker.check_that_all_functions_have_definition , " Deprecated. Same as --allow-undefined." ) ; ( "--include-paths" , Arg.String @@ -256,12 +256,12 @@ let main () = exit 0 ) ; (* Just translate a stan program *) if !model_file = "" then model_file_err () ; - if !Semantic_check.model_name = "" then - Semantic_check.model_name := + if !Typechecker.model_name = "" then + Typechecker.model_name := mangle (remove_dotstan List.(hd_exn (rev (String.split !model_file ~on:'/')))) ^ "_model" - else Semantic_check.model_name := mangle !Semantic_check.model_name ; + else Typechecker.model_name := mangle !Typechecker.model_name ; if !output_file = "" then output_file := remove_dotstan !model_file ^ ".hpp" ; use_file !model_file diff --git a/src/stancjs/stancjs.ml b/src/stancjs/stancjs.ml index 91c31cdbf6..78e3a70894 100644 --- a/src/stancjs/stancjs.ml +++ b/src/stancjs/stancjs.ml @@ -20,8 +20,8 @@ let warn_uninitialized_msgs Set.Poly.(to_list (map filtered_uninit_vars ~f:show_var_info)) let stan2cpp model_name model_string is_flag_set = - Semantic_check.model_name := model_name ; - Semantic_check.check_that_all_functions_have_definition := + Typechecker.model_name := model_name ; + Typechecker.check_that_all_functions_have_definition := not (is_flag_set "allow_undefined" || is_flag_set "allow-undefined") ; Transform_Mir.use_opencl := is_flag_set "use-opencl" ; Stan_math_code_gen.standalone_functions := is_flag_set "standalone-functions" ; @@ -40,15 +40,13 @@ let stan2cpp model_name model_string is_flag_set = let result = ast >>= fun ast -> - let typed_ast_and_warnings = - Semantic_check.semantic_check_program ast - |> Result.map_error ~f:(fun errs -> - Errors.Semantic_error (List.hd_exn errs) ) + let typed_ast = + Typechecker.check_program ast + |> Result.map_error ~f:(fun e -> Errors.Semantic_error e) in - typed_ast_and_warnings - >>| fun typed_ast_and_warnings -> - let typed_ast, semantic_warnings = typed_ast_and_warnings in - let warnings = parser_warnings @ semantic_warnings in + typed_ast + >>| fun (typed_ast, type_warnings) -> + let warnings = parser_warnings @ type_warnings in if is_flag_set "info" then r.return (Result.Ok (Info.info typed_ast), warnings, []) ; if is_flag_set "print-canonical" then diff --git a/test/integration/bad/functions-bad2.stan b/test/integration/bad/functions-bad2.stan index 110a0051d0..bcf9675761 100644 --- a/test/integration/bad/functions-bad2.stan +++ b/test/integration/bad/functions-bad2.stan @@ -1,4 +1,10 @@ functions { + int foobar(int x); + + int foobar(int x){ + return 1; + } + real barfoo(real x); // error not defining barfoo } diff --git a/test/integration/bad/shadow.stan b/test/integration/bad/shadow.stan new file mode 100644 index 0000000000..c887884295 --- /dev/null +++ b/test/integration/bad/shadow.stan @@ -0,0 +1,5 @@ +functions { + int log10(int x) { + return 1; + } +} \ No newline at end of file diff --git a/test/integration/bad/stanc.expected b/test/integration/bad/stanc.expected index 2db41e02f5..94043d6dd0 100644 --- a/test/integration/bad/stanc.expected +++ b/test/integration/bad/stanc.expected @@ -1099,16 +1099,17 @@ Semantic error in 'functions-bad19.stan', line 4, column 2 to line 6, column 3: Function 'my_fun3' has already been declared to have type (data real) => real $ ../../../../install/default/bin/stanc functions-bad2.stan -Semantic error in 'functions-bad2.stan', line 2, column 2 to column 22: +Semantic error in 'functions-bad2.stan', line 8, column 7 to column 13: ------------------------------------------------- - 1: functions { - 2: real barfoo(real x); - ^ - 3: // error not defining barfoo - 4: } + 6: } + 7: + 8: real barfoo(real x); + ^ + 9: // error not defining barfoo + 10: } ------------------------------------------------- -Some function is declared without specifying a definition. +Function is declared without specifying a definition. $ ../../../../install/default/bin/stanc functions-bad20.stan Semantic error in 'functions-bad20.stan', line 4, column 2 to line 6, column 3: ------------------------------------------------- @@ -2041,6 +2042,17 @@ Semantic error in 'row_vector_expr_bad2.stan', line 5, column 25 to column 26: ------------------------------------------------- Row_vector expression must have all int or real entries. Found type row_vector. + $ ../../../../install/default/bin/stanc shadow.stan +Semantic error in 'shadow.stan', line 2, column 7 to column 12: + ------------------------------------------------- + 1: functions { + 2: int log10(int x) { + ^ + 3: return 1; + 4: } + ------------------------------------------------- + +Identifier 'log10' clashes with Stan Math library function. $ ../../../../install/default/bin/stanc signature_function_known.stan Warning in 'signature_function_known.stan', line 8, column 2: increment_log_prob(...); is deprecated and will be removed in the future. Use target += ...; instead. Semantic error in 'signature_function_known.stan', line 8, column 21 to column 50: diff --git a/test/integration/tfp/dune b/test/integration/tfp/dune index 861c933fa3..a793f9cc3b 100644 --- a/test/integration/tfp/dune +++ b/test/integration/tfp/dune @@ -36,12 +36,18 @@ (deps (package stanc) (:stanfiles ./stan_models/python_kwrds.stan)) (action (with-stdout-to %{targets} - (run %{bin:stan2tfp} %{stanfiles})))) + (with-stderr-to "tfp.output" + (run %{bin:stan2tfp} %{stanfiles}))))) (alias (name runtest) (action (diff ./tfp_models/python_kwrds.py python_kwrds.py.output))) + +(alias + (name runtest) + (action (diff ./tfp_models/tfp.expected tfp.output))) + (rule (targets transformed_data.py.output) (deps (package stanc) (:stanfiles ./stan_models/transformed_data.stan)) @@ -50,5 +56,5 @@ (run %{bin:stan2tfp} %{stanfiles})))) (alias - (name runtest) + (name runtest) (action (diff ./tfp_models/transformed_data.py transformed_data.py.output))) diff --git a/test/integration/tfp/tfp_models/tfp.expected b/test/integration/tfp/tfp_models/tfp.expected new file mode 100644 index 0000000000..1f544968d2 --- /dev/null +++ b/test/integration/tfp/tfp_models/tfp.expected @@ -0,0 +1,26 @@ +Warning in 'stan_models/python_kwrds.stan', line 13, column 12: Found int division: + lambda / 3 +Values will be rounded towards zero. If rounding is not desired you can write +the division as + lambda / 3.0 +If rounding is intended please use the integer division operator %/%. +Identifier finally is a reserved word in python, renamed to finally__ +Identifier assert is a reserved word in python, renamed to assert__ +Identifier lambda is a reserved word in python, renamed to lambda__ +Identifier yield is a reserved word in python, renamed to yield__ +Identifier await is a reserved word in python, renamed to await__ +Identifier finally is a reserved word in python, renamed to finally__ +Identifier finally is a reserved word in python, renamed to finally__ +Identifier assert is a reserved word in python, renamed to assert__ +Identifier yield is a reserved word in python, renamed to yield__ +Identifier lambda is a reserved word in python, renamed to lambda__ +Identifier finally is a reserved word in python, renamed to finally__ +Identifier lambda is a reserved word in python, renamed to lambda__ +Identifier await is a reserved word in python, renamed to await__ +Identifier finally is a reserved word in python, renamed to finally__ +Identifier assert is a reserved word in python, renamed to assert__ +Identifier assert is a reserved word in python, renamed to assert__ +Identifier finally is a reserved word in python, renamed to finally__ +Identifier assert is a reserved word in python, renamed to assert__ +Identifier assert is a reserved word in python, renamed to assert__ +Identifier lambda is a reserved word in python, renamed to lambda__ \ No newline at end of file diff --git a/test/stancjs/stancjs.expected b/test/stancjs/stancjs.expected index 923427008a..44cae224c7 100644 --- a/test/stancjs/stancjs.expected +++ b/test/stancjs/stancjs.expected @@ -1,6 +1,6 @@ $ node allow-undefined.js -Semantic error in 'string', line 3, column 4 to column 20: -Some function is declared without specifying a definition. +Semantic error in 'string', line 3, column 8 to column 11: +Function is declared without specifying a definition. $ node auto-format.js parameters { @@ -80,8 +80,8 @@ $ node info.js "included_files": [ ] } $ node optimization.js -Semantic error in 'string', line 3, column 4 to column 20: -Some function is declared without specifying a definition. +Semantic error in 'string', line 3, column 8 to column 11: +Function is declared without specifying a definition. $ node pedantic.js ["Warning in 'string', line 7, column 17: Argument 10000 suggests there may be parameters that are not unit scale; consider rescaling with a multiplier (see manual section 22.12).","Warning: The parameter k was declared but was not used in the density calculation."] diff --git a/test/unit/Parse_tests.ml b/test/unit/Parse_tests.ml index 0c597ff817..68b127e81d 100644 --- a/test/unit/Parse_tests.ml +++ b/test/unit/Parse_tests.ml @@ -4,7 +4,7 @@ open Frontend let print_ast_of_string s = let ast = Frontend_utils.untyped_ast_of_string s - |> Result.map_error ~f:Middle.Errors.to_string + |> Result.map_error ~f:Errors.to_string |> Result.ok_or_failwith in print_s [%sexp (ast : Ast.untyped_program)] diff --git a/test/unit/Pedantic_analysis.ml b/test/unit/Pedantic_analysis.ml index a785035b10..dd88d3c098 100644 --- a/test/unit/Pedantic_analysis.ml +++ b/test/unit/Pedantic_analysis.ml @@ -28,7 +28,7 @@ let sigma_example = let print_warn_pedantic p = p |> warn_pedantic - |> Fmt.strf "%a" (Middle.Warnings.pp_warnings ?printed_filename:None) + |> Fmt.strf "%a" (Warnings.pp_warnings ?printed_filename:None) |> print_endline let%expect_test "Unbounded sigma warning" =