Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 142 additions & 6 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use rustc_index::IndexVec;
use rustc_middle::mir::{self, BasicBlock, Local};
use rustc_middle::ty::{self as mir_ty, TyCtxt};
use rustc_span::def_id::{DefId, LocalDefId};
use rustc_span::symbol::Symbol;

use crate::analyze;
use crate::annot::{AnnotFormula, AnnotParser, Resolver};
Expand All @@ -24,11 +25,38 @@ use crate::refine::{self, BasicBlockType, TypeBuilder};
use crate::rty;

mod annot;
mod annot_fn;
mod basic_block;
mod crate_;
mod did_cache;
mod local_def;

pub fn mir_borrowck_skip_formula_fn(
tcx: rustc_middle::ty::TyCtxt<'_>,
local_def_id: rustc_span::def_id::LocalDefId,
) -> rustc_middle::query::queries::mir_borrowck::ProvidedValue {
// TODO: unify impl with local_def::Analyzer
let is_annotated_as_formula_fn = tcx
.get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::formula_fn_path())
.next()
.is_some();

if is_annotated_as_formula_fn {
tracing::debug!(?local_def_id, "skipping borrow check for formula fn");
let dummy_result = rustc_middle::mir::BorrowCheckResult {
concrete_opaque_types: Default::default(),
closure_requirements: None,
used_mut_upvars: Default::default(),
tainted_by_errors: None,
};
return tcx.arena.alloc(dummy_result);
}

(rustc_interface::DEFAULT_QUERY_PROVIDERS
.queries
.mir_borrowck)(tcx, local_def_id)
}

pub fn local_of_function_param(idx: rty::FunctionParamIdx) -> Local {
Local::from(idx.index() + 1)
}
Expand Down Expand Up @@ -155,6 +183,9 @@ pub struct Analyzer<'tcx> {
/// (at least for every defs referenced by local def bodies)
defs: HashMap<DefId, DefTy<'tcx>>,

/// Collection of functions with `#[thrust::formula_fn]` attribute.
formula_fns: HashMap<DefId, annot_fn::FormulaFn<'tcx>>,

/// Resulting CHC system.
system: Rc<RefCell<chc::System>>,

Expand All @@ -181,12 +212,14 @@ impl<'tcx> Analyzer<'tcx> {
impl<'tcx> Analyzer<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>) -> Self {
let defs = Default::default();
let formula_fns = Default::default();
let system = Default::default();
let basic_blocks = Default::default();
let enum_defs = Default::default();
Self {
tcx,
defs,
formula_fns,
system,
basic_blocks,
def_ids: did_cache::DefIdCache::new(tcx),
Expand Down Expand Up @@ -346,6 +379,11 @@ impl<'tcx> Analyzer<'tcx> {
Some(expected)
}

pub fn register_formula_fn(&mut self, def_id: DefId, formula_fn: annot_fn::FormulaFn<'tcx>) {
tracing::info!(def_id = ?def_id, formula_fn = %formula_fn.display(), "register_formula_fn");
self.formula_fns.insert(def_id, formula_fn);
}

pub fn register_basic_block_ty(
&mut self,
def_id: LocalDefId,
Expand Down Expand Up @@ -434,20 +472,71 @@ impl<'tcx> Analyzer<'tcx> {
self.fn_sig_with_body(def_id, body)
}

fn extract_path_with_attr(
&self,
local_def_id: LocalDefId,
attr_path: &[Symbol],
) -> Option<DefId> {
let map = self.tcx.hir();
let Some(body_id) = map.maybe_body_owned_by(local_def_id) else {
return None;
};
let body = map.body(body_id);

let rustc_hir::ExprKind::Block(block, _) = body.value.kind else {
return None;
};
for stmt in block.stmts {
if map
.attrs(stmt.hir_id)
.iter()
.all(|attr| !attr.path_matches(attr_path))
{
continue;
}
let rustc_hir::StmtKind::Semi(expr) = stmt.kind else {
panic!(
"annotated path is expected to be a semi statement, but found: {:?}",
stmt.kind
);
};
let rustc_hir::ExprKind::Path(qpath) = expr.kind else {
panic!(
"annotated path is expected to be a path expression, but found: {:?}",
expr.kind
);
};
let rustc_hir::QPath::Resolved(_, path) = qpath else {
panic!(
"qpath is unexpectedly not resolved in annotated path: {:?}",
qpath
);
};
let rustc_hir::def::Res::Def(_, def_id) = path.res else {
panic!(
"annotated path is expected to refer to a def, but found: {:?}",
path.res
);
};
return Some(def_id);
}
None
}

fn extract_require_annot<T>(
&self,
def_id: DefId,
local_def_id: LocalDefId,
resolver: T,
self_type_name: Option<String>,
) -> Option<AnnotFormula<T::Output>>
where
T: Resolver,
T: Resolver<Output = rty::FunctionParamIdx>,
{
let mut require_annot = None;
let parser = AnnotParser::new(&resolver, self_type_name);
for attrs in self
.tcx
.get_attrs_by_path(def_id, &analyze::annot::requires_path())
.get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::requires_path())
{
if require_annot.is_some() {
unimplemented!();
Expand All @@ -456,23 +545,47 @@ impl<'tcx> Analyzer<'tcx> {
let require = parser.parse_formula(ts).unwrap();
require_annot = Some(require);
}

if let Some(formula_def_id) =
self.extract_path_with_attr(local_def_id, &analyze::annot::requires_path())
{
if require_annot.is_some() {
unimplemented!();
}
let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else {
panic!(
"require annotation {:?} is not a formula function",
formula_def_id
);
};
let Some(annot) = formula_fn.to_require_annot(self.fn_sig(local_def_id.to_def_id()))
else {
panic!(
"require annotation {:?} has incompatible signature with {:?}",
formula_def_id, local_def_id
);
};
require_annot = Some(annot);
}

require_annot
}

fn extract_ensure_annot<T>(
&self,
def_id: DefId,
local_def_id: LocalDefId,
resolver: T,
self_type_name: Option<String>,
) -> Option<AnnotFormula<T::Output>>
where
T: Resolver,
T: Resolver<Output = rty::RefinedTypeVar<rty::FunctionParamIdx>>,
{
let mut ensure_annot = None;

let parser = AnnotParser::new(&resolver, self_type_name);
for attrs in self
.tcx
.get_attrs_by_path(def_id, &analyze::annot::ensures_path())
.get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::ensures_path())
{
if ensure_annot.is_some() {
unimplemented!();
Expand All @@ -481,6 +594,29 @@ impl<'tcx> Analyzer<'tcx> {
let ensure = parser.parse_formula(ts).unwrap();
ensure_annot = Some(ensure);
}

if let Some(formula_def_id) =
self.extract_path_with_attr(local_def_id, &analyze::annot::ensures_path())
{
if ensure_annot.is_some() {
unimplemented!();
}
let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else {
panic!(
"ensure annotation {:?} is not a formula function",
formula_def_id
);
};
let Some(annot) = formula_fn.to_ensure_annot(self.fn_sig(local_def_id.to_def_id()))
else {
panic!(
"ensure annotation {:?} has incompatible signature with {:?}",
formula_def_id, local_def_id
);
};
ensure_annot = Some(annot);
}

ensure_annot
}

Expand Down
4 changes: 4 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ pub fn predicate_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("predicate")]
}

pub fn formula_fn_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("formula_fn")]
}

/// A [`annot::Resolver`] implementation for resolving function parameters.
///
/// The parameter names and their sorts needs to be configured via
Expand Down
Loading
Loading