use std::{
io::Write,
path::Path,
process::{Command, Stdio},
};
use anyhow::{ensure, Result};
use rustc_data_structures::{captures::Captures, fx::FxHashMap as HashMap};
use rustc_hir::{def_id::DefId, CoroutineDesugaring, CoroutineKind, HirId};
use rustc_middle::{
mir::{pretty::write_mir_fn, *},
ty::{Region, Ty, TyCtxt},
};
use smallvec::SmallVec;
use super::control_dependencies::ControlDependencies;
use crate::{PlaceExt, TyExt};
pub trait BodyExt<'tcx> {
type AllReturnsIter<'a>: Iterator<Item = Location>
where
Self: 'a;
fn all_returns(&self) -> Self::AllReturnsIter<'_>;
type AllLocationsIter<'a>: Iterator<Item = Location>
where
Self: 'a;
fn all_locations(&self) -> Self::AllLocationsIter<'_>;
type LocationsIter: Iterator<Item = Location>;
fn locations_in_block(&self, block: BasicBlock) -> Self::LocationsIter;
fn debug_info_name_map(&self) -> HashMap<String, Local>;
fn to_string(&self, tcx: TyCtxt<'tcx>) -> Result<String>;
fn location_to_hir_id(&self, location: Location) -> HirId;
fn source_info_to_hir_id(&self, info: &SourceInfo) -> HirId;
fn control_dependencies(&self) -> ControlDependencies<BasicBlock>;
fn async_context(&self, tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Ty<'tcx>>;
type PlacesIter<'a>: Iterator<Item = Place<'tcx>>
where
Self: 'a;
fn all_places(&self, tcx: TyCtxt<'tcx>, def_id: DefId) -> Self::PlacesIter<'_>;
type ArgRegionsIter<'a>: Iterator<Item = Region<'tcx>>
where
Self: 'a;
fn regions_in_args(&self) -> Self::ArgRegionsIter<'_>;
type ReturnRegionsIter: Iterator<Item = Region<'tcx>>;
fn regions_in_return(&self) -> Self::ReturnRegionsIter;
}
impl<'tcx> BodyExt<'tcx> for Body<'tcx> {
type AllReturnsIter<'a> = impl Iterator<Item = Location> + Captures<'tcx> + 'a where Self: 'a;
fn all_returns(&self) -> Self::AllReturnsIter<'_> {
self
.basic_blocks
.iter_enumerated()
.filter_map(|(block, data)| match data.terminator().kind {
TerminatorKind::Return => Some(Location {
block,
statement_index: data.statements.len(),
}),
_ => None,
})
}
type AllLocationsIter<'a> = impl Iterator<Item = Location> + Captures<'tcx> + 'a where Self: 'a;
fn all_locations(&self) -> Self::AllLocationsIter<'_> {
self
.basic_blocks
.iter_enumerated()
.flat_map(|(block, data)| {
(0 .. data.statements.len() + 1).map(move |statement_index| Location {
block,
statement_index,
})
})
}
type LocationsIter = impl Iterator<Item = Location>;
fn locations_in_block(&self, block: BasicBlock) -> Self::LocationsIter {
let num_stmts = self.basic_blocks[block].statements.len();
(0 ..= num_stmts).map(move |statement_index| Location {
block,
statement_index,
})
}
fn debug_info_name_map(&self) -> HashMap<String, Local> {
self
.var_debug_info
.iter()
.filter_map(|info| match info.value {
VarDebugInfoContents::Place(place) => Some((info.name.to_string(), place.local)),
_ => None,
})
.collect()
}
fn to_string(&self, tcx: TyCtxt<'tcx>) -> Result<String> {
let mut buffer = Vec::new();
write_mir_fn(tcx, self, &mut |_, _| Ok(()), &mut buffer)?;
Ok(String::from_utf8(buffer)?)
}
fn location_to_hir_id(&self, location: Location) -> HirId {
let source_info = self.source_info(location);
self.source_info_to_hir_id(source_info)
}
fn source_info_to_hir_id(&self, info: &SourceInfo) -> HirId {
let scope = &self.source_scopes[info.scope];
let local_data = scope.local_data.as_ref().assert_crate_local();
local_data.lint_root
}
fn control_dependencies(&self) -> ControlDependencies<BasicBlock> {
ControlDependencies::build_many(
&self.basic_blocks,
self.all_returns().map(|loc| loc.block),
)
}
fn async_context(&self, tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Ty<'tcx>> {
if matches!(
tcx.coroutine_kind(def_id),
Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _))
) {
Some(self.local_decls[Local::from_usize(2)].ty)
} else {
None
}
}
type ArgRegionsIter<'a> = impl Iterator<Item = Region<'tcx>> + Captures<'tcx> + 'a
where Self: 'a;
type ReturnRegionsIter = impl Iterator<Item = Region<'tcx>>;
type PlacesIter<'a> = impl Iterator<Item = Place<'tcx>> + Captures<'tcx> + 'a
where Self: 'a;
fn regions_in_args(&self) -> Self::ArgRegionsIter<'_> {
self
.args_iter()
.flat_map(|arg_local| self.local_decls[arg_local].ty.inner_regions())
}
fn regions_in_return(&self) -> Self::ReturnRegionsIter {
self
.return_ty()
.inner_regions()
.collect::<SmallVec<[Region<'tcx>; 8]>>()
.into_iter()
}
fn all_places(&self, tcx: TyCtxt<'tcx>, def_id: DefId) -> Self::PlacesIter<'_> {
self.local_decls.indices().flat_map(move |local| {
Place::from_local(local, tcx).interior_paths(tcx, self, def_id)
})
}
}
pub fn run_dot(path: &Path, buf: Vec<u8>) -> Result<()> {
let mut p = Command::new("dot")
.args(["-Tpdf", "-o", &path.display().to_string()])
.stdin(Stdio::piped())
.spawn()?;
p.stdin.as_mut().unwrap().write_all(&buf)?;
let status = p.wait()?;
ensure!(status.success(), "dot for {} failed", path.display());
Ok(())
}
#[cfg(test)]
mod test {
use super::BodyExt;
use crate::test_utils;
#[test]
fn test_body_ext() {
let input = r#"
fn foobar<'a>(x: &'a i32, y: &'a i32) -> &'a i32 {
if *x > 0 {
return x;
}
y
}"#;
test_utils::compile_body(input, |_, _, body| {
let body = &body.body;
assert_eq!(body.regions_in_args().count(), 2);
assert_eq!(body.regions_in_return().count(), 1);
});
}
}