Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@

#![deny(missing_docs, missing_debug_implementations)]

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::env::VarError;
use std::fmt::Debug;
Expand Down Expand Up @@ -818,6 +819,155 @@ macro_rules! fail_point {
($name:expr, $cond:expr, $e:expr) => {{}};
}

#[derive(Clone)]
struct SyncCallback1(Arc<Mutex<dyn FnMut(&mut dyn Any) + Send + Sync + 'static>>);

impl Debug for SyncCallback1 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("SyncCallback1()")
}
}

impl PartialEq for SyncCallback1 {
#[allow(clippy::vtable_address_comparisons)]
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}

impl SyncCallback1 {
fn new(f: Box<dyn FnMut(&mut dyn Any) + Send + Sync>) -> SyncCallback1 {
SyncCallback1(Arc::new(Mutex::new(f)))
}

fn run(&mut self, var: &mut dyn Any) {
let callback = &mut self.0.lock().unwrap();
callback(var);
}
}

struct MapEntry {
type_id: TypeId,
cb: SyncCallback1,
}

impl MapEntry {
fn new(type_id: TypeId, cb: Box<dyn FnMut(&mut dyn Any) + Send + Sync>) -> MapEntry {
MapEntry {
type_id: type_id,
cb: SyncCallback1::new(cb),
}
}
}

lazy_static::lazy_static! {
static ref TESTVALUE_REGISTRY: RwLock<HashMap<String, MapEntry>>= Default::default();
}

/// Set the callback for a test value adjustment.
///
/// Usage:
///
/// ```rust
/// fn production_code() {
/// ...
/// let mut var = SomeVar();
/// adjust("adjust_this_var", &mut var);
/// ...
/// }
///
/// fn test_code() {
/// ...
/// let _raii = ScopedCallback::new("adjust_this_var", |var| {
/// *var = SomeNewValue();
/// });
/// ...
/// }
/// ```
///
pub fn set_callback<S, T, F>(name: S, mut f: F) -> Result<(), String>
where
S: Into<String>,
T: Any,
F: FnMut(&mut T) + Send + Sync + 'static,
{
let mut registry = TESTVALUE_REGISTRY.write().unwrap();
registry.insert(
name.into(),
MapEntry::new(
TypeId::of::<T>(),
Box::new(move |var| {
if let Some(var) = var.downcast_mut::<T>() {
f(var);
} else {
panic!("Type mismtach");
}
}),
),
);
Ok(())
}

/// Set a scoped callback using RAII
#[derive(Debug)]
pub struct ScopedCallback {
name: String,
}

impl ScopedCallback {
/// Creates a RAII instance.
pub fn new<S, T, F>(name: S, f: F) -> Self
where
S: Into<String> + Copy,
T: Any,
F: FnMut(&mut T) + Send + Sync + 'static,
{
set_callback(name.clone(), f).unwrap();
ScopedCallback { name: name.into() }
}
}

impl Drop for ScopedCallback {
fn drop(&mut self) {
let mut registry = TESTVALUE_REGISTRY.write().unwrap();
registry.remove(&self.name);
}
}

#[doc(hidden)]
pub fn internal_adjust<S, T>(name: S, var: &mut T)
where
S: Into<String>,
T: Clone + 'static,
{
let mut registry = TESTVALUE_REGISTRY.write().unwrap();
// Clone the var here, since the argument is required to be 'static.
let mut clone = var.clone();
if let Some(entry) = registry.get_mut(&name.into()) {
if (*entry).type_id != TypeId::of::<T>() {
panic!("Type mismatch");
}
(*entry).cb.run(&mut clone);
}
*var = clone;
}

/// Define a test value adjustment (requires `failpoints` feature).
#[macro_export]
#[cfg(feature = "failpoints")]
macro_rules! adjust {
($name:expr, $var:expr) => {{
$crate::internal_adjust($name, $var);
}};
}

/// Define a test value adjustment (disabled, see `failpoints` feature).
#[macro_export]
#[cfg(not(feature = "failpoints"))]
macro_rules! adjust {
($name:expr, $var:expr) => {{}};
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
33 changes: 33 additions & 0 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,36 @@ fn test_list() {
fail::cfg("list", "return").unwrap();
assert!(fail::list().contains(&("list".to_string(), "return".to_string())));
}

#[test]
fn test_value_adjust() {
let f = || -> i32 {
let mut var = 1;
fail::adjust!("adjust_var", &mut var);
var
};
assert_eq!(f(), 1);

fail::set_callback("adjust_var", |vari| {
*vari = 2;
})
.unwrap();
assert_eq!(f(), 2);
}

#[test]
fn test_value_adjust_raii() {
let f = || -> i32 {
let mut var = 1;
fail::adjust!("adjust_var1", &mut var);
var
};
{
let _raii = fail::ScopedCallback::new("adjust_var1", |var| {
*var = 2;
});
assert_eq!(f(), 2);
}

assert_eq!(f(), 1);
}