Skip to content
202 changes: 177 additions & 25 deletions tokio/src/fs/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ use crate::blocking::{spawn_blocking, spawn_mandatory_blocking};
#[cfg(not(test))]
use std::fs::File as StdFile;

cfg_io_uring! {
#[cfg(test)]
use super::mocks::spawn;
#[cfg(not(test))]
use crate::spawn;
}

/// A reference to an open file on the filesystem.
///
/// This is a specialized version of [`std::fs::File`] for usage from the
Expand Down Expand Up @@ -753,20 +760,10 @@ impl AsyncWrite for File {
let n = buf.copy_from(src, me.max_buf_size);
let std = me.std.clone();

let blocking_task_join_handle = spawn_mandatory_blocking(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
buf.write_to(&mut &*std)
};

(Operation::Write(res), buf)
})
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "background task failed")
})?;
#[allow(unused_mut)]
let mut task_join_handle = inner.poll_write_inner((std, buf), seek)?;

inner.state = State::Busy(blocking_task_join_handle);
inner.state = State::Busy(task_join_handle);

return Poll::Ready(Ok(n));
}
Expand Down Expand Up @@ -824,20 +821,88 @@ impl AsyncWrite for File {
let n = buf.copy_from_bufs(bufs, me.max_buf_size);
let std = me.std.clone();

let blocking_task_join_handle = spawn_mandatory_blocking(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
buf.write_to(&mut &*std)
};
#[allow(unused_mut)]
let mut data = Some((std, buf));

let mut task_join_handle: Option<JoinHandle<_>> = None;

#[cfg(all(
tokio_unstable,
feature = "io-uring",
feature = "rt",
feature = "fs",
target_os = "linux"
))]
{
use crate::runtime::Handle;

// Handle not present in some tests?
if let Ok(handle) = Handle::try_current() {
if handle.inner.driver().io().check_and_init()? {
task_join_handle = {
use crate::{io::uring::utils::ArcFd, runtime::driver::op::Op};

let (std, mut buf) = data.take().unwrap();
if let Some(seek) = seek {
// we do std seek before a write, so we can always use u64::MAX (current cursor) for the file offset
// seeking only modifies kernel metadata and does not block, so we can do it here
(&*std).seek(seek).map_err(|e| {
io::Error::new(
e.kind(),
format!("failed to seek before write: {e}"),
)
})?;
}

let mut fd: ArcFd = std;
let handle = spawn(async move {
loop {
let op = Op::write_at(fd, buf, u64::MAX);
let (r, _buf, _fd) = op.await;
buf = _buf;
fd = _fd;
match r {
Ok(_) if buf.is_empty() => {
break (Operation::Write(Ok(())), buf);
}
Ok(0) => {
break (
Operation::Write(Err(
io::ErrorKind::WriteZero.into(),
)),
buf,
);
}
Ok(_) => continue, // more to write
Err(e) => break (Operation::Write(Err(e)), buf),
}
}
});

Some(handle)
};
}
}
}

(Operation::Write(res), buf)
})
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "background task failed")
})?;
if let Some((std, mut buf)) = data {
task_join_handle = Some(
spawn_mandatory_blocking(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
buf.write_to(&mut &*std)
};

(Operation::Write(res), buf)
})
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "background task failed")
})?,
);
}

inner.state = State::Busy(blocking_task_join_handle);
inner.state = State::Busy(task_join_handle.unwrap());

return Poll::Ready(Ok(n));
}
Expand Down Expand Up @@ -989,6 +1054,93 @@ impl Inner {
Operation::Seek(_) => Poll::Ready(Ok(())),
}
}

/// Starts the background task that writes the buffer in `data` to the file in
/// `data`, applying `seek` first when one is supplied.
///
/// When the io_uring configuration is compiled in, a runtime handle is
/// available and the I/O driver reports that io_uring is usable, the write is
/// driven by an async task issuing `Op::write_at` until the buffer is empty.
/// In every other case the seek and write run on the blocking pool through
/// `spawn_mandatory_blocking`.
///
/// # Errors
///
/// Returns an error if the io_uring availability check fails, if the seek
/// performed ahead of an io_uring write fails, or if the blocking task could
/// not be spawned.
fn poll_write_inner(
    &self,
    data: (Arc<StdFile>, Buf),
    seek: Option<SeekFrom>,
) -> io::Result<JoinHandle<(Operation, Buf)>> {
    #[cfg(all(
        tokio_unstable,
        feature = "io-uring",
        feature = "rt",
        feature = "fs",
        target_os = "linux"
    ))]
    {
        use crate::io::uring::utils::ArcFd;
        use crate::runtime::driver::op::Op;
        use crate::runtime::Handle;

        // A runtime handle is not always available (this appears to happen in
        // some tests); without one we fall through to the blocking path.
        if let Ok(rt) = Handle::try_current() {
            if rt.inner.driver().io().check_and_init()? {
                let (std, mut buf) = data;

                if let Some(pos) = seek {
                    // The seek is done synchronously, ahead of the write, so the
                    // write itself can always target the current cursor
                    // (`u64::MAX`). Seeking only modifies kernel metadata and
                    // does not block, so it is fine to do it here.
                    if let Err(e) = (&*std).seek(pos) {
                        return Err(io::Error::new(
                            e.kind(),
                            format!("failed to seek before write: {e}"),
                        ));
                    }
                }

                let mut fd: ArcFd = std;
                let task = spawn(async move {
                    let outcome: io::Result<()> = loop {
                        // The operation hands the file and buffer back on
                        // completion so they can be reused for the next round.
                        let (written, returned_buf, returned_fd) =
                            Op::write_at(fd, buf, u64::MAX).await;
                        buf = returned_buf;
                        fd = returned_fd;

                        match written {
                            // Everything has been flushed.
                            Ok(_) if buf.is_empty() => break Ok(()),
                            // Data remains but nothing was accepted.
                            Ok(0) => break Err(io::ErrorKind::WriteZero.into()),
                            // Partial write: go around again for the rest.
                            Ok(_) => {}
                            Err(e) => break Err(e),
                        }
                    };

                    (Operation::Write(outcome), buf)
                });

                return Ok(task);
            }
        }
    }

    // Blocking fallback: seek (if requested) and write on the blocking pool.
    let (std, mut buf) = data;
    let spawned = spawn_mandatory_blocking(move || {
        let res = match seek {
            Some(pos) => (&*std).seek(pos).and_then(|_| buf.write_to(&mut &*std)),
            None => buf.write_to(&mut &*std),
        };

        (Operation::Write(res), buf)
    });

    spawned.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "background task failed"))
}
}

#[cfg(test)]
Expand Down
16 changes: 16 additions & 0 deletions tokio/src/fs/mocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,22 @@ where
Some(JoinHandle { rx })
}

/// Mock counterpart of `crate::spawn` for the file tests.
///
/// Runs `f` on a task created with `crate::spawn` and forwards its output
/// through a oneshot channel to the returned mock `JoinHandle`.
// NOTE(review): presumably unused unless the io-uring configuration is enabled,
// hence the `dead_code` allowance — confirm.
#[allow(dead_code)]
pub(super) fn spawn<F>(f: F) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let (sender, receiver) = oneshot::channel();

    // The task's own join handle is not needed; the result travels through
    // the channel instead.
    let _detached = crate::spawn(async move {
        // A failed send is deliberately ignored.
        let _ = sender.send(f.await);
    });

    JoinHandle { rx: receiver }
}

impl<T> Future for JoinHandle<T> {
type Output = Result<T, io::Error>;

Expand Down
35 changes: 21 additions & 14 deletions tokio/src/fs/write.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
#[cfg(all(
tokio_unstable,
feature = "io-uring",
feature = "rt",
feature = "fs",
target_os = "linux"
))]
use crate::io::blocking;
use crate::{fs::asyncify, util::as_ref::OwnedBuf};

use std::{io, path::Path};
Expand Down Expand Up @@ -25,7 +33,6 @@ use std::{io, path::Path};
/// ```
pub async fn write(path: impl AsRef<Path>, contents: impl AsRef<[u8]>) -> io::Result<()> {
let path = path.as_ref();
let contents = crate::util::as_ref::upgrade(contents);

#[cfg(all(
tokio_unstable,
Expand All @@ -38,10 +45,13 @@ pub async fn write(path: impl AsRef<Path>, contents: impl AsRef<[u8]>) -> io::Re
let handle = crate::runtime::Handle::current();
let driver_handle = handle.inner.driver().io();
if driver_handle.check_and_init()? {
return write_uring(path, contents).await;
let mut buf = blocking::Buf::with_capacity(contents.as_ref().len());
buf.copy_from(contents.as_ref(), contents.as_ref().len());
return write_uring(path, buf).await;
}
}

let contents = crate::util::as_ref::upgrade(contents);
write_spawn_blocking(path, contents).await
}

Expand All @@ -52,9 +62,9 @@ pub async fn write(path: impl AsRef<Path>, contents: impl AsRef<[u8]>) -> io::Re
feature = "fs",
target_os = "linux"
))]
async fn write_uring(path: &Path, mut buf: OwnedBuf) -> io::Result<()> {
use crate::{fs::OpenOptions, runtime::driver::op::Op};
use std::os::fd::OwnedFd;
async fn write_uring(path: &Path, mut buf: blocking::Buf) -> io::Result<()> {
use crate::{fs::OpenOptions, io::uring::utils::ArcFd, runtime::driver::op::Op};
use std::sync::Arc;

let file = OpenOptions::new()
.write(true)
Expand All @@ -63,16 +73,14 @@ async fn write_uring(path: &Path, mut buf: OwnedBuf) -> io::Result<()> {
.open(path)
.await?;

let mut fd: OwnedFd = file
.try_into_std()
.expect("unexpected in-flight operation detected")
.into();
let mut fd: ArcFd = Arc::new(
file.try_into_std()
.expect("unexpected in-flight operation detected"),
);

let total: usize = buf.as_ref().len();
let mut buf_offset: usize = 0;
let mut file_offset: u64 = 0;
while buf_offset < total {
let (n, _buf, _fd) = Op::write_at(fd, buf, buf_offset, file_offset)?.await;
while !buf.is_empty() {
let (n, _buf, _fd) = Op::write_at(fd, buf, file_offset).await;
// TODO: handle EINTR here
let n = n?;
if n == 0 {
Expand All @@ -81,7 +89,6 @@ async fn write_uring(path: &Path, mut buf: OwnedBuf) -> io::Result<()> {

buf = _buf;
fd = _fd;
buf_offset += n as usize;
file_offset += n as u64;
}

Expand Down
19 changes: 19 additions & 0 deletions tokio/src/io/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,25 @@ impl Buf {
&self.buf[self.pos..]
}

/// Moves the buffer's position forward by `n` bytes.
///
/// When the position reaches the end of the stored bytes, the storage is
/// emptied and the position is reset to zero.
///
/// # Panics
///
/// Panics if `n` is greater than `self.len()`.
#[cfg(all(
    tokio_unstable,
    feature = "io-uring",
    feature = "rt",
    feature = "fs",
    target_os = "linux"
))]
pub(crate) fn advance(&mut self, n: usize) {
    assert!(n <= self.len(), "advance past end of buffer");

    self.pos += n;

    // Fully consumed: reset so the buffer reports itself as empty again.
    if self.pos == self.buf.len() {
        self.buf.clear();
        self.pos = 0;
    }
}

/// # Safety
///
/// `rd` must not read from the buffer `read` is borrowing and must correctly
Expand Down
4 changes: 4 additions & 0 deletions tokio/src/io/uring/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::os::fd::AsRawFd;
use std::os::unix::ffi::OsStrExt;
use std::sync::Arc;
use std::{ffi::CString, io, path::Path};

/// A shared, type-erased handle to any value that can expose a raw file
/// descriptor and may be sent to and shared between threads.
pub(crate) type ArcFd = Arc<dyn AsRawFd + Send + Sync + 'static>;

/// Builds a NUL-terminated C string from the raw OS bytes of `p`.
///
/// # Errors
///
/// Returns an [`io::Error`] if the path contains an interior NUL byte.
pub(crate) fn cstr(p: &Path) -> io::Result<CString> {
    let raw = p.as_os_str().as_bytes();
    CString::new(raw).map_err(io::Error::from)
}
Loading