diff --git a/.gitignore b/.gitignore index ea8c4bf..2f7896d 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1 @@ -/target +target/ diff --git a/Cargo.lock b/Cargo.lock index acdb4d0..d262154 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -711,11 +711,21 @@ dependencies = [ "process_control", "tempfile", "thiserror 2.0.18", + "thrust-macros", "tracing", "tracing-subscriber", "ui_test", ] +[[package]] +name = "thrust-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing" version = "0.1.44" diff --git a/Cargo.toml b/Cargo.toml index 2cc1907..72b9bd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,7 @@ +[workspace] +members = [".", "thrust-macros"] +resolver = "2" + [package] name = "thrust" version = "0.1.0" @@ -18,6 +22,11 @@ name = "ui" harness = false [dependencies] +# Workaround until thrust supports invocation via cargo: adding thrust-macros as +# a dependency ensures cargo builds (and rebuilds) it alongside thrust. It is not +# used as a proc-macro here; the dylib is loaded at runtime via --extern. +thrust-macros = { path = "thrust-macros" } + anyhow = "1.0.102" pretty = { version = "0.12.5", features = ["termcolor"] } tempfile = "3.27.0" diff --git a/src/analyze.rs b/src/analyze.rs index 85ba3b2..523b56d 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -39,10 +39,16 @@ pub fn mir_borrowck_skip_formula_fn( local_def_id: rustc_span::def_id::LocalDefId, ) -> rustc_middle::query::queries::mir_borrowck::ProvidedValue { // TODO: unify impl with local_def::Analyzer + // if the def is closure defined in formula_fn + let root_def_id = tcx.typeck_root_def_id(local_def_id.to_def_id()); let is_annotated_as_formula_fn = tcx .get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::formula_fn_path()) .next() - .is_some(); + .is_some() + || tcx + .get_attrs_by_path(root_def_id, &analyze::annot::formula_fn_path()) + .next() + .is_some(); if is_annotated_as_formula_fn { tracing::debug!(?local_def_id, "skipping borrow check for formula fn"); diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 51e21db..b3e9208 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -100,6 +100,19 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { return; } + // skip analysis if the def is closure defined in skipped def + if let Some(root_local_def_id) = self + .tcx + .typeck_root_def_id(local_def_id.to_def_id()) + .as_local() + { + if root_local_def_id != local_def_id && self.skip_analysis.contains(&root_local_def_id) + { + self.skip_analysis.insert(local_def_id); + return; + } + } + let target_def_id = if analyzer.is_annotated_as_extern_spec_fn() { analyzer.extern_spec_fn_target_def_id() } else { diff --git a/src/main.rs b/src/main.rs index 391424e..1a26967 100644 --- a/src/main.rs +++ b/src/main.rs @@ -66,8 +66,36 @@ impl Callbacks for CompilerCalls { } } +fn thrust_macros_path() -> Option { + let exe = std::env::current_exe().ok()?; + let dir = exe.parent()?; + // When thrust-macros is a cargo dependency it lands in deps/ with a hash + // suffix (e.g. libthrust_macros-.so), so check both locations. + let search_dirs = [dir.to_path_buf(), dir.join("deps")]; + for search_dir in &search_dirs { + for ext in ["so", "dylib", "dll"] { + // First try the exact name (e.g. when built standalone). + let candidate = search_dir.join(format!("libthrust_macros.{ext}")); + if candidate.exists() { + return Some(candidate); + } + // Then scan for libthrust_macros-.. + if let Ok(entries) = std::fs::read_dir(search_dir) { + for entry in entries.flatten() { + let name = entry.file_name(); + let name = name.to_string_lossy(); + if name.starts_with("libthrust_macros-") && name.ends_with(ext) { + return Some(entry.path()); + } + } + } + } + } + None +} + pub fn main() { - let args = std::env::args().collect::>(); + let mut args = std::env::args().collect::>(); use tracing_subscriber::{filter::EnvFilter, prelude::*}; tracing_subscriber::registry() @@ -80,6 +108,14 @@ pub fn main() { ) .init(); + if let Some(path) = thrust_macros_path() { + args.push("--extern".to_owned()); + args.push(format!("thrust_macros={}", path.display())); + tracing::debug!("linking thrust_macros from {}", path.display()); + } else { + tracing::warn!("could not locate thrust_macros library"); + } + let code = rustc_driver::catch_with_exit_code(|| RunCompiler::new(&args, &mut CompilerCalls {}).run()); std::process::exit(code); diff --git a/std.rs b/std.rs index 3318656..04b93a3 100644 --- a/std.rs +++ b/std.rs @@ -244,84 +244,107 @@ mod thrust_models { } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures(result == )] -fn _extern_spec_box_new(x: T) -> Box where T: thrust_models::Model { +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(result == thrust_models::model::Box::new(x))] +fn _extern_spec_box_new(x: T) -> Box where T: thrust_models::Model, T::Ty: PartialEq { Box::new(x) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures(*x == ^y && *y == ^x)] -fn _extern_spec_std_mem_swap(x: &mut T, y: &mut T) where T: thrust_models::Model { +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(result == (x == y))] +fn _extern_spec_box_partialeq_eq(x: &Box, y: &Box) -> bool + where T: thrust_models::Model + PartialEq, T::Ty: PartialEq +{ + as PartialEq>::eq(x, y) +} + +#[thrust::extern_spec_fn] +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(*x == !y && *y == !x)] +fn _extern_spec_std_mem_swap(x: &mut T, y: &mut T) where T: thrust_models::Model, T::Ty: PartialEq { std::mem::swap(x, y) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures(^dest == src && result == *dest)] -fn _extern_spec_std_mem_replace(dest: &mut T, src: T) -> T where T: thrust_models::Model { +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(!dest == src && result == *dest)] +fn _extern_spec_std_mem_replace(dest: &mut T, src: T) -> T where T: thrust_models::Model, T::Ty: PartialEq { std::mem::replace(dest, src) } #[thrust::extern_spec_fn] -#[thrust::requires(opt != std::option::Option::::None())] -#[thrust::ensures(std::option::Option::::Some(result) == opt)] -fn _extern_spec_option_unwrap(opt: Option) -> T where T: thrust_models::Model { +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(result == (x == y))] +fn _extern_spec_option_partialeq_eq(x: &Option, y: &Option) -> bool + where T: thrust_models::Model + PartialEq, T::Ty: PartialEq +{ + as PartialEq>::eq(x, y) +} + +#[thrust::extern_spec_fn] +#[thrust_macros::requires(opt != None)] +#[thrust_macros::ensures(Some(result) == opt)] +fn _extern_spec_option_unwrap(opt: Option) -> T where T: thrust_models::Model, T::Ty: PartialEq { Option::unwrap(opt) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures( - (*opt == std::option::Option::::None() && result == true) - || (*opt != std::option::Option::::None() && result == false) +#[thrust_macros::requires(true)] +#[thrust_macros::ensures( + (*opt == None && result == true) + || (*opt != None && result == false) )] -fn _extern_spec_option_is_none(opt: &Option) -> bool where T: thrust_models::Model { +fn _extern_spec_option_is_none(opt: &Option) -> bool where T: thrust_models::Model, T::Ty: PartialEq { Option::is_none(opt) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures( - (*opt == std::option::Option::::None() && result == false) - || (*opt != std::option::Option::::None() && result == true) +#[thrust_macros::requires(true)] +#[thrust_macros::ensures( + (*opt == None && result == false) + || (*opt != None && result == true) )] -fn _extern_spec_option_is_some(opt: &Option) -> bool where T: thrust_models::Model { +fn _extern_spec_option_is_some(opt: &Option) -> bool where T: thrust_models::Model, T::Ty: PartialEq { Option::is_some(opt) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures( - (opt != std::option::Option::::None() && std::option::Option::::Some(result) == opt) - || (opt == std::option::Option::::None() && result == default) +#[thrust_macros::requires(true)] +#[thrust_macros::ensures( + (opt != None && Some(result) == opt) + || (opt == None && result == default) )] -fn _extern_spec_option_unwrap_or(opt: Option, default: T) -> T where T: thrust_models::Model { +fn _extern_spec_option_unwrap_or(opt: Option, default: T) -> T where T: thrust_models::Model, T::Ty: PartialEq { Option::unwrap_or(opt, default) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures( - (exists x:T0. opt == std::option::Option::::Some(x) && result == std::result::Result::::Ok(x)) - || (opt == std::option::Option::::None() && result == std::result::Result::::Err(err)) +#[thrust_macros::requires(true)] +#[thrust_macros::ensures( + (thrust_models::exists(|x| opt == Some(x) && result == Ok(x))) + || (opt == None && result == Err(err)) )] -fn _extern_spec_option_ok_or(opt: Option, err: E) -> Result where T: thrust_models::Model, E: thrust_models::Model { +fn _extern_spec_option_ok_or(opt: Option, err: E) -> Result + where T: thrust_models::Model, T::Ty: PartialEq, + E: thrust_models::Model, E::Ty: PartialEq, +{ Option::ok_or(opt, err) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures(^opt == std::option::Option::::None() && result == *opt)] -fn _extern_spec_option_take(opt: &mut Option) -> Option where T: thrust_models::Model { +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(!opt == None && result == *opt)] +fn _extern_spec_option_take(opt: &mut Option) -> Option where T: thrust_models::Model, T::Ty: PartialEq { Option::take(opt) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures(^opt == std::option::Option::::Some(src) && result == *opt)] -fn _extern_spec_option_replace(opt: &mut Option, src: T) -> Option where T: thrust_models::Model { +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(!opt == Some(src) && result == *opt)] +fn _extern_spec_option_replace(opt: &mut Option, src: T) -> Option + where T: thrust_models::Model, T::Ty: PartialEq +{ Option::replace(opt, src) } @@ -336,54 +359,78 @@ fn _extern_spec_option_as_ref(opt: &Option) -> Option<&T> where T: thrust_ } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures( - (exists x1:T0, x2:T0. - *opt == std::option::Option::::Some(x1) && - ^opt == std::option::Option::::Some(x2) && - result == std::option::Option::<&mut T0>::Some() +#[thrust_macros::requires(true)] +#[thrust_macros::ensures( + thrust_models::exists(|x1, x2| + *opt == Some(x1) && + !opt == Some(x2) && + result == Some(thrust_models::model::Mut::new(x1, x2)) ) || ( - *opt == std::option::Option::::None() && - ^opt == std::option::Option::::None() && - result == std::option::Option::<&mut T0>::None() + *opt == None && + !opt == None && + result == None ) )] -fn _extern_spec_option_as_mut(opt: &mut Option) -> Option<&mut T> where T: thrust_models::Model { +fn _extern_spec_option_as_mut(opt: &mut Option) -> Option<&mut T> + where T: thrust_models::Model, T::Ty: PartialEq +{ Option::as_mut(opt) } #[thrust::extern_spec_fn] -#[thrust::requires(exists x:T0. res == std::result::Result::::Ok(x))] -#[thrust::ensures(std::result::Result::::Ok(result) == res)] -fn _extern_spec_result_unwrap(res: Result) -> T where T: thrust_models::Model { +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(result == (x == y))] +fn _extern_spec_result_partialeq_eq(x: &Result, y: &Result) -> bool + where T: thrust_models::Model + PartialEq, T::Ty: PartialEq, + E: thrust_models::Model + PartialEq, E::Ty: PartialEq, +{ + as PartialEq>::eq(x, y) +} + +#[thrust::extern_spec_fn] +#[thrust_macros::requires(thrust_models::exists(|x| res == Ok(x)))] +#[thrust_macros::ensures(Ok(result) == res)] +fn _extern_spec_result_unwrap(res: Result) -> T + where T: thrust_models::Model, T::Ty: PartialEq, + E: thrust_models::Model, E::Ty: PartialEq, +{ Result::unwrap(res) } #[thrust::extern_spec_fn] -#[thrust::requires(exists x:T1. res == std::result::Result::::Err(x))] -#[thrust::ensures(std::result::Result::::Err(result) == res)] -fn _extern_spec_result_unwrap_err(res: Result) -> E where T: thrust_models::Model, E: thrust_models::Model { +#[thrust_macros::requires(thrust_models::exists(|x| res == Err(x)))] +#[thrust_macros::ensures(Err(result) == res)] +fn _extern_spec_result_unwrap_err(res: Result) -> E + where T: thrust_models::Model, T::Ty: PartialEq, + E: thrust_models::Model, E::Ty: PartialEq, +{ Result::unwrap_err(res) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures( - (exists x:T0. res == std::result::Result::::Ok(x) && result == std::option::Option::::Some(x)) - || (exists x:T1. res == std::result::Result::::Err(x) && result == std::option::Option::::None()) +#[thrust_macros::requires(true)] +#[thrust_macros::ensures( + thrust_models::exists(|x| res == Ok(x) && result == Some(x)) + || thrust_models::exists(|x| res == Err(x) && result == None) )] -fn _extern_spec_result_ok(res: Result) -> Option where T: thrust_models::Model, E: thrust_models::Model { +fn _extern_spec_result_ok(res: Result) -> Option + where T: thrust_models::Model, T::Ty: PartialEq, + E: thrust_models::Model, E::Ty: PartialEq, +{ Result::ok(res) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures( - (exists x:T0. res == std::result::Result::::Ok(x) && result == std::option::Option::::None()) - || (exists x:T1. res == std::result::Result::::Err(x) && result == std::option::Option::::Some(x)) +#[thrust_macros::requires(true)] +#[thrust_macros::ensures( + thrust_models::exists(|x| res == Ok(x) && result == None) + || thrust_models::exists(|x| res == Err(x) && result == Some(x)) )] -fn _extern_spec_result_err(res: Result) -> Option where T: thrust_models::Model, E: thrust_models::Model { +fn _extern_spec_result_err(res: Result) -> Option + where T: thrust_models::Model, T::Ty: PartialEq, + E: thrust_models::Model, E::Ty: PartialEq, +{ Result::err(res) } @@ -408,15 +455,15 @@ fn _extern_spec_result_is_err(res: &Result) -> bool where T: thrust_ } #[thrust::extern_spec_fn] -#[thrust::requires(true)] // TODO: require x != i32::MIN -#[thrust::ensures(result >= 0 && (result == x || result == -x))] +#[thrust_macros::requires(true)] // TODO: require x != i32::MIN +#[thrust_macros::ensures(result >= 0 && (result == x || result == -x))] fn _extern_spec_i32_abs(x: i32) -> i32 { i32::abs(x) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures( +#[thrust_macros::requires(true)] +#[thrust_macros::ensures( (x >= y && result == (x - y)) || (x < y && result == (y - x)) )] @@ -425,22 +472,22 @@ fn _extern_spec_i32_abs_diff(x: i32, y: i32) -> u32 { } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures((x == 0 && result == 0) || (x > 0 && result == 1) || (x < 0 && result == -1))] +#[thrust_macros::requires(true)] +#[thrust_macros::ensures((x == 0 && result == 0) || (x > 0 && result == 1) || (x < 0 && result == -1))] fn _extern_spec_i32_signum(x: i32) -> i32 { i32::signum(x) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures((x < 0 && result == false) || (x >= 0 && result == true))] +#[thrust_macros::requires(true)] +#[thrust_macros::ensures((x < 0 && result == false) || (x >= 0 && result == true))] fn _extern_spec_i32_is_positive(x: i32) -> bool { i32::is_positive(x) } #[thrust::extern_spec_fn] -#[thrust::requires(true)] -#[thrust::ensures((x <= 0 && result == true) || (x > 0 && result == false))] +#[thrust_macros::requires(true)] +#[thrust_macros::ensures((x <= 0 && result == true) || (x > 0 && result == false))] fn _extern_spec_i32_is_negative(x: i32) -> bool { i32::is_negative(x) } diff --git a/tests/ui/fail/annot_enum_simple.rs b/tests/ui/fail/annot_enum_simple.rs index b0f9077..f97225a 100644 --- a/tests/ui/fail/annot_enum_simple.rs +++ b/tests/ui/fail/annot_enum_simple.rs @@ -1,12 +1,17 @@ //@error-in-other-file: Unsat +#[derive(PartialEq)] pub enum X { A(i64), B(bool), } -#[thrust::requires(x == X::A(1))] -#[thrust::ensures(true)] +impl thrust_models::Model for X { + type Ty = X; +} + +#[thrust_macros::requires(x == X::A(1))] +#[thrust_macros::ensures(true)] fn test(x: X) { if let X::A(i) = x { assert!(i == 2); diff --git a/tests/ui/fail/annot_exists.rs b/tests/ui/fail/annot_exists.rs index 2558b59..496cfae 100644 --- a/tests/ui/fail/annot_exists.rs +++ b/tests/ui/fail/annot_exists.rs @@ -2,12 +2,14 @@ //@compile-flags: -C debug-assertions=off //@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper +use thrust_models::exists; + #[thrust::trusted] #[thrust::callable] fn rand() -> i32 { unimplemented!() } -#[thrust::requires(true)] -#[thrust::ensures(exists x:int. result == 2 * x)] +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(exists(|x: i32| result == 2 * x))] fn f() -> i32 { let x = rand(); x + x + x diff --git a/tests/ui/fail/annot_mut_term.rs b/tests/ui/fail/annot_mut_term.rs index f1037dd..c7035b1 100644 --- a/tests/ui/fail/annot_mut_term.rs +++ b/tests/ui/fail/annot_mut_term.rs @@ -1,7 +1,9 @@ //@error-in-other-file: Unsat -#[thrust::requires(true)] -#[thrust::ensures(x == <*x, y>)] +use thrust_models::model::Mut; + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(x == Mut::new(*x, y))] fn f(x: &mut i64, y: i64) { *x = y; } diff --git a/tests/ui/fail/tuple_annot.rs b/tests/ui/fail/tuple_annot.rs index 9e5c4e1..f283236 100644 --- a/tests/ui/fail/tuple_annot.rs +++ b/tests/ui/fail/tuple_annot.rs @@ -1,8 +1,8 @@ //@error-in-other-file: Unsat //@compile-flags: -C debug-assertions=off -#[thrust::requires(true)] -#[thrust::ensures(^x == ((*x).1, (*x).1))] +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(!x == ((*x).1, (*x).1))] fn swap_tuple(x: &mut (i32, i32)) { let v = *x; *x = (v.1, v.0); diff --git a/tests/ui/pass/annot_enum_simple.rs b/tests/ui/pass/annot_enum_simple.rs index 70eed4c..9ed6184 100644 --- a/tests/ui/pass/annot_enum_simple.rs +++ b/tests/ui/pass/annot_enum_simple.rs @@ -1,12 +1,17 @@ //@check-pass +#[derive(PartialEq)] pub enum X { A(i64), B(bool), } -#[thrust::requires(x == X::A(1))] -#[thrust::ensures(true)] +impl thrust_models::Model for X { + type Ty = X; +} + +#[thrust_macros::requires(x == X::A(1))] +#[thrust_macros::ensures(true)] fn test(x: X) { if let X::A(i) = x { assert!(i == 1); diff --git a/tests/ui/pass/annot_exists.rs b/tests/ui/pass/annot_exists.rs index 1496c7f..3ab0003 100644 --- a/tests/ui/pass/annot_exists.rs +++ b/tests/ui/pass/annot_exists.rs @@ -2,12 +2,14 @@ //@compile-flags: -C debug-assertions=off //@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper +use thrust_models::exists; + #[thrust::trusted] #[thrust::callable] fn rand() -> i32 { unimplemented!() } -#[thrust::requires(true)] -#[thrust::ensures(exists x:int. result == 2 * x)] +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(exists(|x: i32| result == 2 * x))] fn f() -> i32 { let x = rand(); x + x diff --git a/tests/ui/pass/annot_mut_term.rs b/tests/ui/pass/annot_mut_term.rs index 172a1df..60aa17c 100644 --- a/tests/ui/pass/annot_mut_term.rs +++ b/tests/ui/pass/annot_mut_term.rs @@ -1,7 +1,9 @@ //@check-pass -#[thrust::requires(true)] -#[thrust::ensures(x == <*x, y>)] +use thrust_models::model::Mut; + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(x == Mut::new(*x, y))] fn f(x: &mut i64, y: i64) { *x = y; } diff --git a/tests/ui/pass/take_max_annot.rs b/tests/ui/pass/take_max_annot.rs index 1f467a6..a752e2c 100644 --- a/tests/ui/pass/take_max_annot.rs +++ b/tests/ui/pass/take_max_annot.rs @@ -1,9 +1,8 @@ //@check-pass //@compile-flags: -C debug-assertions=off -#[thrust::requires(true)] -#[thrust::ensures(true)] #[thrust::trusted] +#[thrust::callable] fn rand() -> i64 { unimplemented!() } #[thrust::requires(true)] diff --git a/tests/ui/pass/tuple_annot.rs b/tests/ui/pass/tuple_annot.rs index 530706f..4ab7274 100644 --- a/tests/ui/pass/tuple_annot.rs +++ b/tests/ui/pass/tuple_annot.rs @@ -1,8 +1,8 @@ //@check-pass //@compile-flags: -C debug-assertions=off -#[thrust::requires(true)] -#[thrust::ensures(^x == ((*x).1, (*x).0))] +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(!x == ((*x).1, (*x).0))] fn swap_tuple(x: &mut (i32, i32)) { let v = *x; *x = (v.1, v.0); diff --git a/thrust-macros/Cargo.lock b/thrust-macros/Cargo.lock new file mode 100644 index 0000000..b4a86db --- /dev/null +++ b/thrust-macros/Cargo.lock @@ -0,0 +1,47 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "proc-macro2" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "002a1b3dbf967edfafc32655d0f377ab0bb7b994aa1d32c8cc7e9b8bf3ebb8f0" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thrust-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" diff --git a/thrust-macros/Cargo.toml b/thrust-macros/Cargo.toml new file mode 100644 index 0000000..b85219f --- /dev/null +++ b/thrust-macros/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "thrust-macros" +version = "0.1.0" +authors = ["coord_e "] +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1" +quote = "=1.0.36" +syn = { version = "2", features = ["full"] } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs new file mode 100644 index 0000000..8370c1c --- /dev/null +++ b/thrust-macros/src/lib.rs @@ -0,0 +1,313 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + parse_macro_input, FnArg, GenericParam, Generics, ItemFn, Pat, ReturnType, Type, + TypeParamBound, WherePredicate, +}; + +/// Transforms a function annotated with `#[requires(req)]` and `#[ensures(ens)]` into +/// an internal multi-function representation for the Thrust analysis backend. +/// +/// This macro must be paired with `#[ensures(...)]` on the same function. +/// +/// # Generated output +/// +/// For `#[requires(req)] #[ensures(ens)] fn f(a: A) -> R`: +/// +/// - The original function `f` (unchanged, minus the two thrust attrs) +/// - `_thrust_requires_f(a: ::Ty) -> bool` with body `req` +/// - `_thrust_ensures_f(result: ::Ty, a: ::Ty) -> bool` with body `ens` +/// - `_thrust_extern_spec_f(a: A) -> R` referencing the above via path statements +#[proc_macro_attribute] +pub fn requires(attr: TokenStream, item: TokenStream) -> TokenStream { + let req_expr = TokenStream2::from(attr); + let mut func = parse_macro_input!(item as ItemFn); + + // Find and extract the paired #[ensures(...)] attribute + let ensures_pos = func.attrs.iter().position(|a| { + a.path() + .segments + .last() + .map_or(false, |s| s.ident == "ensures") + }); + + let ens_expr = match ensures_pos { + Some(idx) => { + let attr = func.attrs.remove(idx); + match attr.parse_args::() { + Ok(ts) => ts, + Err(e) => return e.to_compile_error().into(), + } + } + None => { + return syn::Error::new_spanned( + &func.sig.ident, + "#[requires] must be paired with #[ensures]", + ) + .to_compile_error() + .into(); + } + }; + + // Strip any stray requires attrs (defensive — they'd cause infinite recursion) + func.attrs.retain(|a| { + a.path() + .segments + .last() + .map_or(true, |s| s.ident != "requires") + }); + + let is_extern_spec = func.attrs.iter().any(|a| { + a.path() + .segments + .last() + .map_or(false, |s| s.ident == "extern_spec_fn") + }); + + if is_extern_spec { + transform_extern_spec(func, req_expr, ens_expr) + .unwrap_or_else(|e| e.to_compile_error().into()) + } else { + transform(func, req_expr, ens_expr).unwrap_or_else(|e| e.to_compile_error().into()) + } +} + +/// Pass-through no-op. When `#[ensures]` is consumed by `#[requires]`, this is never +/// reached. It exists so that `#[ensures]` is a recognized proc macro attribute and does +/// not produce an "unknown attribute" error if applied in isolation. +#[proc_macro_attribute] +pub fn ensures(_attr: TokenStream, item: TokenStream) -> TokenStream { + item +} + +fn transform_extern_spec( + mut func: ItemFn, + req_expr: TokenStream2, + ens_expr: TokenStream2, +) -> syn::Result { + let name = &func.sig.ident; + let requires_name = format_ident!("_thrust_requires_{}", name); + let ensures_name = format_ident!("_thrust_ensures_{}", name); + + let generics = &func.sig.generics; + + let def_generics = generic_params_tokens(generics); + let turbofish = generic_turbofish(generics); + let model_preds = model_where_predicates(generics); + let extended_where = extended_where_clause(generics, &model_preds); + + // Parameters with types replaced by ::Ty + let model_ty_params = fn_params_with_model_ty(&func.sig.inputs); + + // Return type for ensures helper (wrapped in ::Ty) + let ret_model_ty: Type = match &func.sig.output { + ReturnType::Default => syn::parse_quote!(<() as thrust_models::Model>::Ty), + ReturnType::Type(_, ty) => syn::parse_quote!(<#ty as thrust_models::Model>::Ty), + }; + + // Prepend requires_path/ensures_path statements to the original block + let orig_stmts = func.block.stmts.clone(); + func.block = syn::parse_quote!({ + #[thrust::requires_path] + #requires_name #turbofish; + #[thrust::ensures_path] + #ensures_name #turbofish; + #(#orig_stmts)* + }); + + let output = quote! { + #[allow(unused_variables)] + #[thrust::formula_fn] + fn #requires_name #def_generics(#model_ty_params) -> bool #extended_where { + #req_expr + } + + #[allow(unused_variables)] + #[thrust::formula_fn] + fn #ensures_name #def_generics(result: #ret_model_ty, #model_ty_params) -> bool #extended_where { + #ens_expr + } + + #[allow(path_statements)] + #func + }; + + Ok(output.into()) +} + +fn transform( + func: ItemFn, + req_expr: TokenStream2, + ens_expr: TokenStream2, +) -> syn::Result { + let name = &func.sig.ident; + let requires_name = format_ident!("_thrust_requires_{}", name); + let ensures_name = format_ident!("_thrust_ensures_{}", name); + let extern_spec_name = format_ident!("_thrust_extern_spec_{}", name); + + let generics = &func.sig.generics; + + // for function definitions (params only, no where clause) + let def_generics = generic_params_tokens(generics); + + // :: for turbofish (empty if no params) + let turbofish = generic_turbofish(generics); + + // Additional where predicates: T: thrust_models::Model for qualifying type params + let model_preds = model_where_predicates(generics); + + // Full where clause combining original predicates and model predicates + let extended_where = extended_where_clause(generics, &model_preds); + + // Parameters with types replaced by ::Ty + let model_ty_params = fn_params_with_model_ty(&func.sig.inputs); + + // Argument idents for function call expressions + let call_args = fn_arg_idents(&func.sig.inputs); + + // Return type for requires/ensures helper (wrapped in ::Ty) + let ret_model_ty: Type = match &func.sig.output { + ReturnType::Default => syn::parse_quote!(<() as thrust_models::Model>::Ty), + ReturnType::Type(_, ty) => syn::parse_quote!(<#ty as thrust_models::Model>::Ty), + }; + + let orig_inputs = &func.sig.inputs; + let orig_output = &func.sig.output; + + let output = quote! { + #func + + #[allow(unused_variables)] + #[thrust::formula_fn] + fn #requires_name #def_generics(#model_ty_params) -> bool #extended_where { + #req_expr + } + + #[allow(unused_variables)] + #[thrust::formula_fn] + fn #ensures_name #def_generics(result: #ret_model_ty, #model_ty_params) -> bool #extended_where { + #ens_expr + } + + #[thrust::extern_spec_fn] + #[allow(path_statements)] + fn #extern_spec_name #def_generics(#orig_inputs) #orig_output #extended_where { + #[thrust::requires_path] + #requires_name #turbofish; + + #[thrust::ensures_path] + #ensures_name #turbofish; + + #name #turbofish(#call_args) + } + }; + + Ok(output.into()) +} + +/// Returns `` — the generic param list for function definitions, +/// without a where clause. +fn generic_params_tokens(generics: &Generics) -> TokenStream2 { + if generics.params.is_empty() { + return quote!(); + } + let params = &generics.params; + quote!(<#params>) +} + +/// Returns `::` for turbofish use, or nothing if no generic params. +fn generic_turbofish(generics: &Generics) -> TokenStream2 { + if generics.params.is_empty() { + return quote!(); + } + let args: Vec = generics + .params + .iter() + .map(|p| match p { + GenericParam::Type(tp) => tp.ident.to_token_stream(), + GenericParam::Lifetime(lp) => lp.lifetime.to_token_stream(), + GenericParam::Const(cp) => cp.ident.to_token_stream(), + }) + .collect(); + quote!(::<#(#args),*>) +} + +/// Returns `T: thrust_models::Model` predicates for every type param that does not +/// already carry an `Fn`, `FnOnce`, or `FnMut` bound. +fn model_where_predicates(generics: &Generics) -> Vec { + generics + .params + .iter() + .filter_map(|p| { + let GenericParam::Type(tp) = p else { + return None; + }; + let has_fn_bound = tp.bounds.iter().any(|b| { + let TypeParamBound::Trait(tb) = b else { + return false; + }; + tb.path.segments.last().map_or(false, |s| { + matches!(s.ident.to_string().as_str(), "Fn" | "FnOnce" | "FnMut") + }) + }); + if has_fn_bound { + return None; + } + let ident = &tp.ident; + Some(syn::parse_quote!(#ident: thrust_models::Model)) + }) + .collect() +} + +/// Builds `where , `. +/// Returns an empty token stream when both sets are empty. +fn extended_where_clause(generics: &Generics, model_preds: &[WherePredicate]) -> TokenStream2 { + let existing: Vec<&WherePredicate> = generics + .where_clause + .as_ref() + .map(|wc| wc.predicates.iter().collect()) + .unwrap_or_default(); + + if existing.is_empty() && model_preds.is_empty() { + return quote!(); + } + + quote! { where #(#existing,)* #(#model_preds),* } +} + +/// Maps each typed function parameter `x: T` to `x: ::Ty`. +/// Receiver (`self`) parameters are skipped. +fn fn_params_with_model_ty( + inputs: &syn::punctuated::Punctuated, +) -> TokenStream2 { + let params: Vec = inputs + .iter() + .filter_map(|arg| { + let FnArg::Typed(pt) = arg else { return None }; + let pat = &pt.pat; + let ty = &pt.ty; + let model_ty: Type = syn::parse_quote!(<#ty as thrust_models::Model>::Ty); + Some(quote!(#pat: #model_ty)) + }) + .collect(); + quote!(#(#params),*) +} + +/// Extracts argument identifiers from function parameters for use in a call expression. +fn fn_arg_idents(inputs: &syn::punctuated::Punctuated) -> TokenStream2 { + let args: Vec = inputs + .iter() + .filter_map(|arg| match arg { + FnArg::Typed(pt) => { + if let Pat::Ident(pi) = &*pt.pat { + Some(pi.ident.to_token_stream()) + } else { + None + } + } + FnArg::Receiver(_) => Some(quote!(self)), + }) + .collect(); + quote!(#(#args),*) +}