Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions tokio/src/runtime/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,7 @@ impl Builder {
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn build_local(&mut self, options: LocalOptions) -> io::Result<LocalRuntime> {
match &self.kind {
Kind::CurrentThread => self.build_current_thread_local_runtime(),
Kind::CurrentThread => self.build_current_thread_local_runtime(options),
#[cfg(feature = "rt-multi-thread")]
Kind::MultiThread => panic!("multi_thread is not supported for LocalRuntime"),
}
Expand Down Expand Up @@ -1522,11 +1522,16 @@ impl Builder {
}

#[cfg(tokio_unstable)]
fn build_current_thread_local_runtime(&mut self) -> io::Result<LocalRuntime> {
fn build_current_thread_local_runtime(
&mut self,
opts: LocalOptions,
) -> io::Result<LocalRuntime> {
use crate::runtime::local_runtime::LocalRuntimeScheduler;

let tid = std::thread::current().id();

self.before_park = opts.before_park;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this should be done only if opts.on_xyz_abc is Some.

Use case:

let opts = LocalOptions::default();
// no on_before_park/on_after_unpark for opts !

Builder::new_current_thread()
    .enable_time()
    .on_before_park(...)  // 1
    .build_local(opts)    // 2
    .unwrap();

Currently the user application sets on_before_park callback at 1) and then 2) will silently wipe them

self.after_unpark = opts.after_unpark;
Comment on lines +1533 to +1534
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is adding the support for the following hooks also safe? If so, we should also add them.

  • on_after_task_poll
  • on_before_task_poll
  • on_task_spawn
  • on_task_terminate

let (scheduler, handle, blocking_pool) =
self.build_current_thread_runtime_components(Some(tid))?;

Expand Down
149 changes: 145 additions & 4 deletions tokio/src/runtime/local_runtime/options.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,159 @@
use std::marker::PhantomData;

use crate::runtime::Callback;

/// [`LocalRuntime`]-only config options
///
/// Currently, there are no such options, but in the future, things like `!Send + !Sync` hooks may
/// be added.
///
/// Use `LocalOptions::default()` to create the default set of options. This type is used with
/// [`Builder::build_local`].
///
/// When using [`Builder::build_local`], this overrides any pre-configured options set on the
/// [`Builder`].
///
/// [`Builder::build_local`]: crate::runtime::Builder::build_local
/// [`LocalRuntime`]: crate::runtime::LocalRuntime
#[derive(Default, Debug)]
/// [`Builder`]: crate::runtime::Builder
#[derive(Default)]
#[non_exhaustive]
#[allow(missing_debug_implementations)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pub struct LocalOptions {
/// Marker used to make this !Send and !Sync.
_phantom: PhantomData<*mut u8>,
Comment on lines 20 to 21
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’s unusual to place PhantomData as the first field — please move it to the end.


/// To run before the local runtime is parked.
pub(crate) before_park: Option<Callback>,

/// To run before the local runtime is spawned.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// To run before the local runtime is spawned.
/// To run after the local runtime is unparked.

pub(crate) after_unpark: Option<Callback>,
}

impl std::fmt::Debug for LocalOptions {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LocalOptions")
.field("before_park", &self.before_park.as_ref().map(|_| "..."))
.field("after_unpark", &self.after_unpark.as_ref().map(|_| "..."))
.finish()
}
}

impl LocalOptions {
/// Executes function `f` just before the local runtime is parked (goes idle).
/// `f` is called within the Tokio context, so functions like [`tokio::spawn`](crate::spawn)
/// can be called, and may result in this thread being unparked immediately.
///
/// This can be used to start work only when the executor is idle, or for bookkeeping
/// and monitoring purposes.
///
/// This differs from the [`Builder::on_thread_park`] method in that it accepts a non Send + Sync
/// closure.
///
/// Note: There can only be one park callback for a runtime; calling this function
/// more than once replaces the last callback defined, rather than adding to it.
///
/// # Examples
///
/// ```
/// # use tokio::runtime::{Builder, LocalOptions};
/// # pub fn main() {
/// let (tx, rx) = std::sync::mpsc::channel();
/// let mut opts = LocalOptions::default();
/// opts.on_thread_park(move || match rx.recv() {
/// Ok(x) => println!("Received from channel: {}", x),
/// Err(e) => println!("Error receiving from channel: {}", e),
/// });
///
/// let runtime = Builder::new_current_thread()
/// .enable_time()
/// .build_local(opts)
/// .unwrap();
///
/// runtime.block_on(async {
/// tokio::task::spawn_local(async move {
/// tx.send(42).unwrap();
/// });
/// tokio::time::sleep(std::time::Duration::from_millis(1)).await;
/// })
/// # }
/// ```
///
/// [`Builder`]: crate::runtime::Builder
/// [`Builder::on_thread_park`]: crate::runtime::Builder::on_thread_park
pub fn on_thread_park<F>(&mut self, f: F) -> &mut Self
where
F: Fn() + 'static,
{
self.before_park = Some(std::sync::Arc::new(to_send_sync(f)));
self
}

/// Executes function `f` just after the local runtime unparks (starts executing tasks).
///
/// This is intended for bookkeeping and monitoring use cases; note that work
/// in this callback will increase latencies when the application has allowed one or
/// more runtime threads to go idle.
///
/// This differs from the [`Builder::on_thread_unpark`] method in that it accepts a non Send + Sync
/// closure.
///
/// Note: There can only be one unpark callback for a runtime; calling this function
/// more than once replaces the last callback defined, rather than adding to it.
///
/// # Examples
///
/// ```
/// # use tokio::runtime::{Builder, LocalOptions};
/// # pub fn main() {
/// let (tx, rx) = std::sync::mpsc::channel();
/// let mut opts = LocalOptions::default();
/// opts.on_thread_unpark(move || match rx.recv() {
/// Ok(x) => println!("Received from channel: {}", x),
/// Err(e) => println!("Error receiving from channel: {}", e),
/// });
///
/// let runtime = Builder::new_current_thread()
/// .enable_time()
/// .build_local(opts)
/// .unwrap();
///
/// runtime.block_on(async {
/// tokio::task::spawn_local(async move {
/// tx.send(42).unwrap();
/// });
/// tokio::time::sleep(std::time::Duration::from_millis(1)).await;
/// })
/// # }
/// ```
///
/// [`Builder`]: crate::runtime::Builder
/// [`Builder::on_thread_unpark`]: crate::runtime::Builder::on_thread_unpark
pub fn on_thread_unpark<F>(&mut self, f: F) -> &mut Self
where
F: Fn() + 'static,
{
self.after_unpark = Some(std::sync::Arc::new(to_send_sync(f)));
self
}
}

// A wrapper type to allow non-Send + Sync closures to be used in a Send + Sync context.
// This is specifically used for executing callbacks when using a `LocalRuntime`.
struct UnsafeSendSync<T>(T);

// SAFETY: This type is only used in a context where it is guaranteed that the closure will not be
// sent across threads.
unsafe impl<T> Send for UnsafeSendSync<T> {}
unsafe impl<T> Sync for UnsafeSendSync<T> {}

impl<T: Fn()> UnsafeSendSync<T> {
fn call(&self) {
(self.0)()
}
}

fn to_send_sync<F>(f: F) -> impl Fn() + Send + Sync
where
F: Fn(),
{
let f = UnsafeSendSync(f);
move || f.call()
}
80 changes: 71 additions & 9 deletions tokio/tests/rt_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use tokio::task::LocalSet;

#[test]
fn test_spawn_local_in_runtime() {
let rt = rt();
let rt = rt(LocalOptions::default());

let res = rt.block_on(async move {
let (tx, rx) = tokio::sync::oneshot::channel();
Expand All @@ -24,9 +24,71 @@ fn test_spawn_local_in_runtime() {
assert_eq!(res, 5);
}

#[test]
fn test_on_thread_park_unpark_in_runtime() {
let mut opts = LocalOptions::default();

// the refcell makes the below callbacks `!Send + !Sync`
let on_park_called = std::rc::Rc::new(std::cell::RefCell::new(false));
let on_park_cc = on_park_called.clone();
opts.on_thread_park(move || {
*on_park_cc.borrow_mut() = true;
});

let on_unpark_called = std::rc::Rc::new(std::cell::RefCell::new(false));
let on_unpark_cc = on_unpark_called.clone();
opts.on_thread_unpark(move || {
*on_unpark_cc.borrow_mut() = true;
});
let rt = rt(opts);

rt.block_on(async move {
let (tx, rx) = tokio::sync::oneshot::channel();

spawn_local(async {
tokio::task::yield_now().await;
tx.send(5).unwrap();
});

// this ensures on_thread_park is called
rx.await.unwrap()
});

assert!(*on_park_called.borrow());
assert!(*on_unpark_called.borrow());
}

#[test]
fn test_on_thread_park_unpark_in_handle() {
let mut opts = LocalOptions::default();

// the refcell makes the below callbacks `!Send + !Sync`
let on_park_called = std::rc::Rc::new(std::cell::RefCell::new(false));
let on_park_cc = on_park_called.clone();
opts.on_thread_park(move || {
*on_park_cc.borrow_mut() = true;
});

let on_unpark_called = std::rc::Rc::new(std::cell::RefCell::new(false));
let on_unpark_cc = on_unpark_called.clone();
opts.on_thread_unpark(move || {
*on_unpark_cc.borrow_mut() = true;
});
let rt = rt(opts);

rt.block_on(async move {
tokio::task::yield_now().await;
});

// assert that the callbacks were not called - `Handle::block_on` can not drive IO or timer
// drivers on a current-thread runtime, so the park/unpark callbacks should not be invoked.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment above is not clear to me.
Where is the difference with the code for test_on_thread_park_unpark_in_runtime() ?
Until line https://github.com/tokio-rs/tokio/pull/7420/files#diff-3c0fef031dc5f261b671919a47345cbbc60841e2ffcd3287df279290b7297933R79 they are exactly the same. They differ in the bodies of block_on().

assert!(!*on_park_called.borrow());
assert!(!*on_unpark_called.borrow());
}

#[test]
fn test_spawn_from_handle() {
let rt = rt();
let rt = rt(LocalOptions::default());

let (tx, rx) = tokio::sync::oneshot::channel();

Expand All @@ -42,7 +104,7 @@ fn test_spawn_from_handle() {

#[test]
fn test_spawn_local_on_runtime_object() {
let rt = rt();
let rt = rt(LocalOptions::default());

let (tx, rx) = tokio::sync::oneshot::channel();

Expand All @@ -58,7 +120,7 @@ fn test_spawn_local_on_runtime_object() {

#[test]
fn test_spawn_local_from_guard() {
let rt = rt();
let rt = rt(LocalOptions::default());

let (tx, rx) = tokio::sync::oneshot::channel();

Expand All @@ -80,7 +142,7 @@ fn test_spawn_from_guard_other_thread() {
let (tx, rx) = std::sync::mpsc::channel();

std::thread::spawn(move || {
let rt = rt();
let rt = rt(LocalOptions::default());
let handle = rt.handle().clone();

tx.send(handle).unwrap();
Expand All @@ -100,7 +162,7 @@ fn test_spawn_local_from_guard_other_thread() {
let (tx, rx) = std::sync::mpsc::channel();

std::thread::spawn(move || {
let rt = rt();
let rt = rt(LocalOptions::default());
let handle = rt.handle().clone();

tx.send(handle).unwrap();
Expand All @@ -123,7 +185,7 @@ fn test_spawn_local_from_guard_other_thread() {
#[test]
#[cfg_attr(target_family = "wasm", ignore)] // threads not supported
fn test_spawn_local_panic() {
let rt = rt();
let rt = rt(LocalOptions::default());
let local = LocalSet::new();

rt.block_on(local.run_until(async {
Expand Down Expand Up @@ -162,9 +224,9 @@ fn test_spawn_local_in_multi_thread_runtime() {
})
}

fn rt() -> tokio::runtime::LocalRuntime {
fn rt(opts: LocalOptions) -> tokio::runtime::LocalRuntime {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build_local(LocalOptions::default())
.build_local(opts)
.unwrap()
}