Skip to content

Commit 980ac53

Browse files
committed
rt: add support for non-send closures for thread (un)parking
Add support for non `Send`+`Sync` closures for thread parking and unparking callbacks when using a `LocalRuntime`. Since a `LocalRuntime` will always run its tasks on the same thread, its safe to accept a non `Send`+`Sync` closure. Signed-off-by: Sanskar Jaiswal <[email protected]>
1 parent cae083a commit 980ac53

File tree

3 files changed

+223
-15
lines changed

3 files changed

+223
-15
lines changed

tokio/src/runtime/builder.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,7 @@ impl Builder {
975975
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
976976
pub fn build_local(&mut self, options: LocalOptions) -> io::Result<LocalRuntime> {
977977
match &self.kind {
978-
Kind::CurrentThread => self.build_current_thread_local_runtime(),
978+
Kind::CurrentThread => self.build_current_thread_local_runtime(options),
979979
#[cfg(feature = "rt-multi-thread")]
980980
Kind::MultiThread => panic!("multi_thread is not supported for LocalRuntime"),
981981
}
@@ -1522,11 +1522,16 @@ impl Builder {
15221522
}
15231523

15241524
#[cfg(tokio_unstable)]
1525-
fn build_current_thread_local_runtime(&mut self) -> io::Result<LocalRuntime> {
1525+
fn build_current_thread_local_runtime(
1526+
&mut self,
1527+
opts: LocalOptions,
1528+
) -> io::Result<LocalRuntime> {
15261529
use crate::runtime::local_runtime::LocalRuntimeScheduler;
15271530

15281531
let tid = std::thread::current().id();
15291532

1533+
self.before_park = opts.before_park;
1534+
self.after_unpark = opts.after_unpark;
15301535
let (scheduler, handle, blocking_pool) =
15311536
self.build_current_thread_runtime_components(Some(tid))?;
15321537

Lines changed: 145 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,159 @@
11
use std::marker::PhantomData;
22

3+
use crate::runtime::Callback;
4+
35
/// [`LocalRuntime`]-only config options
46
///
5-
/// Currently, there are no such options, but in the future, things like `!Send + !Sync` hooks may
6-
/// be added.
7-
///
87
/// Use `LocalOptions::default()` to create the default set of options. This type is used with
98
/// [`Builder::build_local`].
109
///
10+
/// When using [`Builder::build_local`], this overrides any pre-configured options set on the
11+
/// [`Builder`].
12+
///
1113
/// [`Builder::build_local`]: crate::runtime::Builder::build_local
1214
/// [`LocalRuntime`]: crate::runtime::LocalRuntime
13-
#[derive(Default, Debug)]
15+
/// [`Builder`]: crate::runtime::Builder
16+
#[derive(Default)]
1417
#[non_exhaustive]
18+
#[allow(missing_debug_implementations)]
1519
pub struct LocalOptions {
1620
/// Marker used to make this !Send and !Sync.
1721
_phantom: PhantomData<*mut u8>,
22+
23+
/// To run before the local runtime is parked.
24+
pub(crate) before_park: Option<Callback>,
25+
26+
/// To run before the local runtime is spawned.
27+
pub(crate) after_unpark: Option<Callback>,
28+
}
29+
30+
impl std::fmt::Debug for LocalOptions {
31+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32+
f.debug_struct("LocalOptions")
33+
.field("before_park", &self.before_park.as_ref().map(|_| "..."))
34+
.field("after_unpark", &self.after_unpark.as_ref().map(|_| "..."))
35+
.finish()
36+
}
37+
}
38+
39+
impl LocalOptions {
40+
/// Executes function `f` just before the local runtime is parked (goes idle).
41+
/// `f` is called within the Tokio context, so functions like [`tokio::spawn`](crate::spawn)
42+
/// can be called, and may result in this thread being unparked immediately.
43+
///
44+
/// This can be used to start work only when the executor is idle, or for bookkeeping
45+
/// and monitoring purposes.
46+
///
47+
/// This differs from the [`Builder::on_thread_park`] method in that it accepts a non Send + Sync
48+
/// closure.
49+
///
50+
/// Note: There can only be one park callback for a runtime; calling this function
51+
/// more than once replaces the last callback defined, rather than adding to it.
52+
///
53+
/// # Examples
54+
///
55+
/// ```
56+
/// # use tokio::runtime::{Builder, LocalOptions};
57+
/// # pub fn main() {
58+
/// let (tx, rx) = std::sync::mpsc::channel();
59+
/// let mut opts = LocalOptions::default();
60+
/// opts.on_thread_park(move || match rx.recv() {
61+
/// Ok(x) => println!("Received from channel: {}", x),
62+
/// Err(e) => println!("Error receiving from channel: {}", e),
63+
/// });
64+
///
65+
/// let runtime = Builder::new_current_thread()
66+
/// .enable_time()
67+
/// .build_local(opts)
68+
/// .unwrap();
69+
///
70+
/// runtime.block_on(async {
71+
/// tokio::task::spawn_local(async move {
72+
/// tx.send(42).unwrap();
73+
/// });
74+
/// tokio::time::sleep(std::time::Duration::from_millis(1)).await;
75+
/// })
76+
/// # }
77+
/// ```
78+
///
79+
/// [`Builder`]: crate::runtime::Builder
80+
/// [`Builder::on_thread_park`]: crate::runtime::Builder::on_thread_park
81+
pub fn on_thread_park<F>(&mut self, f: F) -> &mut Self
82+
where
83+
F: Fn() + 'static,
84+
{
85+
self.before_park = Some(std::sync::Arc::new(to_send_sync(f)));
86+
self
87+
}
88+
89+
/// Executes function `f` just after the local runtime unparks (starts executing tasks).
90+
///
91+
/// This is intended for bookkeeping and monitoring use cases; note that work
92+
/// in this callback will increase latencies when the application has allowed one or
93+
/// more runtime threads to go idle.
94+
///
95+
/// This differs from the [`Builder::on_thread_unpark`] method in that it accepts a non Send + Sync
96+
/// closure.
97+
///
98+
/// Note: There can only be one unpark callback for a runtime; calling this function
99+
/// more than once replaces the last callback defined, rather than adding to it.
100+
///
101+
/// # Examples
102+
///
103+
/// ```
104+
/// # use tokio::runtime::{Builder, LocalOptions};
105+
/// # pub fn main() {
106+
/// let (tx, rx) = std::sync::mpsc::channel();
107+
/// let mut opts = LocalOptions::default();
108+
/// opts.on_thread_unpark(move || match rx.recv() {
109+
/// Ok(x) => println!("Received from channel: {}", x),
110+
/// Err(e) => println!("Error receiving from channel: {}", e),
111+
/// });
112+
///
113+
/// let runtime = Builder::new_current_thread()
114+
/// .enable_time()
115+
/// .build_local(opts)
116+
/// .unwrap();
117+
///
118+
/// runtime.block_on(async {
119+
/// tokio::task::spawn_local(async move {
120+
/// tx.send(42).unwrap();
121+
/// });
122+
/// tokio::time::sleep(std::time::Duration::from_millis(1)).await;
123+
/// })
124+
/// # }
125+
/// ```
126+
///
127+
/// [`Builder`]: crate::runtime::Builder
128+
/// [`Builder::on_thread_unpark`]: crate::runtime::Builder::on_thread_unpark
129+
pub fn on_thread_unpark<F>(&mut self, f: F) -> &mut Self
130+
where
131+
F: Fn() + 'static,
132+
{
133+
self.after_unpark = Some(std::sync::Arc::new(to_send_sync(f)));
134+
self
135+
}
136+
}
137+
138+
// A wrapper type to allow non-Send + Sync closures to be used in a Send + Sync context.
139+
// This is specifically used for executing callbacks when using a `LocalRuntime`.
140+
struct UnsafeSendSync<T>(T);
141+
142+
// SAFETY: This type is only used in a context where it is guaranteed that the closure will not be
143+
// sent across threads.
144+
unsafe impl<T> Send for UnsafeSendSync<T> {}
145+
unsafe impl<T> Sync for UnsafeSendSync<T> {}
146+
147+
impl<T: Fn()> UnsafeSendSync<T> {
148+
fn call(&self) {
149+
(self.0)()
150+
}
151+
}
152+
153+
fn to_send_sync<F>(f: F) -> impl Fn() + Send + Sync
154+
where
155+
F: Fn(),
156+
{
157+
let f = UnsafeSendSync(f);
158+
move || f.call()
18159
}

tokio/tests/rt_local.rs

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use tokio::task::LocalSet;
88

99
#[test]
1010
fn test_spawn_local_in_runtime() {
11-
let rt = rt();
11+
let rt = rt(LocalOptions::default());
1212

1313
let res = rt.block_on(async move {
1414
let (tx, rx) = tokio::sync::oneshot::channel();
@@ -24,9 +24,71 @@ fn test_spawn_local_in_runtime() {
2424
assert_eq!(res, 5);
2525
}
2626

27+
#[test]
28+
fn test_on_thread_park_unpark_in_runtime() {
29+
let mut opts = LocalOptions::default();
30+
31+
// the refcell makes the below callbacks `!Send + !Sync`
32+
let on_park_called = std::rc::Rc::new(std::cell::RefCell::new(false));
33+
let on_park_cc = on_park_called.clone();
34+
opts.on_thread_park(move || {
35+
*on_park_cc.borrow_mut() = true;
36+
});
37+
38+
let on_unpark_called = std::rc::Rc::new(std::cell::RefCell::new(false));
39+
let on_unpark_cc = on_unpark_called.clone();
40+
opts.on_thread_unpark(move || {
41+
*on_unpark_cc.borrow_mut() = true;
42+
});
43+
let rt = rt(opts);
44+
45+
rt.block_on(async move {
46+
let (tx, rx) = tokio::sync::oneshot::channel();
47+
48+
spawn_local(async {
49+
tokio::task::yield_now().await;
50+
tx.send(5).unwrap();
51+
});
52+
53+
// this ensures on_thread_park is called
54+
rx.await.unwrap()
55+
});
56+
57+
assert!(*on_park_called.borrow());
58+
assert!(*on_unpark_called.borrow());
59+
}
60+
61+
#[test]
62+
fn test_on_thread_park_unpark_in_handle() {
63+
let mut opts = LocalOptions::default();
64+
65+
// the refcell makes the below callbacks `!Send + !Sync`
66+
let on_park_called = std::rc::Rc::new(std::cell::RefCell::new(false));
67+
let on_park_cc = on_park_called.clone();
68+
opts.on_thread_park(move || {
69+
*on_park_cc.borrow_mut() = true;
70+
});
71+
72+
let on_unpark_called = std::rc::Rc::new(std::cell::RefCell::new(false));
73+
let on_unpark_cc = on_unpark_called.clone();
74+
opts.on_thread_unpark(move || {
75+
*on_unpark_cc.borrow_mut() = true;
76+
});
77+
let rt = rt(opts);
78+
79+
rt.block_on(async move {
80+
tokio::task::yield_now().await;
81+
});
82+
83+
// assert that the callbacks were not called - `Handle::block_on` can not drive IO or timer
84+
// drivers on a current-thread runtime, so the park/unpark callbacks should not be invoked.
85+
assert!(!*on_park_called.borrow());
86+
assert!(!*on_unpark_called.borrow());
87+
}
88+
2789
#[test]
2890
fn test_spawn_from_handle() {
29-
let rt = rt();
91+
let rt = rt(LocalOptions::default());
3092

3193
let (tx, rx) = tokio::sync::oneshot::channel();
3294

@@ -42,7 +104,7 @@ fn test_spawn_from_handle() {
42104

43105
#[test]
44106
fn test_spawn_local_on_runtime_object() {
45-
let rt = rt();
107+
let rt = rt(LocalOptions::default());
46108

47109
let (tx, rx) = tokio::sync::oneshot::channel();
48110

@@ -58,7 +120,7 @@ fn test_spawn_local_on_runtime_object() {
58120

59121
#[test]
60122
fn test_spawn_local_from_guard() {
61-
let rt = rt();
123+
let rt = rt(LocalOptions::default());
62124

63125
let (tx, rx) = tokio::sync::oneshot::channel();
64126

@@ -80,7 +142,7 @@ fn test_spawn_from_guard_other_thread() {
80142
let (tx, rx) = std::sync::mpsc::channel();
81143

82144
std::thread::spawn(move || {
83-
let rt = rt();
145+
let rt = rt(LocalOptions::default());
84146
let handle = rt.handle().clone();
85147

86148
tx.send(handle).unwrap();
@@ -100,7 +162,7 @@ fn test_spawn_local_from_guard_other_thread() {
100162
let (tx, rx) = std::sync::mpsc::channel();
101163

102164
std::thread::spawn(move || {
103-
let rt = rt();
165+
let rt = rt(LocalOptions::default());
104166
let handle = rt.handle().clone();
105167

106168
tx.send(handle).unwrap();
@@ -123,7 +185,7 @@ fn test_spawn_local_from_guard_other_thread() {
123185
#[test]
124186
#[cfg_attr(target_family = "wasm", ignore)] // threads not supported
125187
fn test_spawn_local_panic() {
126-
let rt = rt();
188+
let rt = rt(LocalOptions::default());
127189
let local = LocalSet::new();
128190

129191
rt.block_on(local.run_until(async {
@@ -162,9 +224,9 @@ fn test_spawn_local_in_multi_thread_runtime() {
162224
})
163225
}
164226

165-
fn rt() -> tokio::runtime::LocalRuntime {
227+
fn rt(opts: LocalOptions) -> tokio::runtime::LocalRuntime {
166228
tokio::runtime::Builder::new_current_thread()
167229
.enable_all()
168-
.build_local(LocalOptions::default())
230+
.build_local(opts)
169231
.unwrap()
170232
}

0 commit comments

Comments
 (0)