Skip to content

Commit 4667eb2

Browse files
authored
io_uring fix stop facade (#977)
Using the stopping facade writes currently the wrong pointer into the `IORING_OP_ASYNC_CANCEL` operation. This PR fixes this by passing the pointer to `__task` which is used for the original submission.
1 parent f140910 commit 4667eb2

File tree

5 files changed

+60
-39
lines changed

5 files changed

+60
-39
lines changed

.github/workflows/ci.gpu.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ jobs:
4545
cd /workspaces/stdexec;
4646
# Configure
4747
cmake -S . -B build -GNinja \
48+
-DSTDEXEC_ENABLE_IO_URING_TESTS=OFF \
4849
-DSTDEXEC_ENABLE_CUDA=ON \
4950
-DCMAKE_CXX_COMPILER="$cxx" \
5051
-DCMAKE_CUDA_COMPILER="$cxx" \

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ if (STDEXEC_ENABLE_TBB)
276276
)
277277
endif ()
278278

279+
option (STDEXEC_ENABLE_IO_URING_TESTS "Enable io_uring tests" ON)
280+
279281
option(STDEXEC_BUILD_EXAMPLES "Build stdexec examples" ON)
280282
option(STDEXEC_BUILD_TESTS "Build stdexec tests" ON)
281283
option(BUILD_TESTING "" ${STDEXEC_BUILD_TESTS})

include/exec/linux/io_uring_context.hpp

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ namespace exec {
328328
};
329329

330330
class __scheduler;
331-
331+
332332
enum class until {
333333
stopped,
334334
empty
@@ -402,26 +402,28 @@ namespace exec {
402402
///
403403
/// This function is not thread-safe and must only be called from the thread that drives the io context.
404404
void run_some() noexcept {
405-
__n_submitted_ -= __completion_queue_.complete();
405+
__n_total_submitted_ -= __completion_queue_.complete();
406406
STDEXEC_ASSERT(
407-
0 <= __n_submitted_
408-
&& __n_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
409-
__u32 __max_submissions = __params_.cq_entries - static_cast<__u32>(__n_submitted_);
407+
0 <= __n_total_submitted_
408+
&& __n_total_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
409+
__u32 __max_submissions = __params_.cq_entries - static_cast<__u32>(__n_total_submitted_);
410410
__pending_.append(__requests_.pop_all());
411411
__submission_result __result = __submission_queue_.submit(
412412
(__task_queue&&) __pending_, __max_submissions, __stop_source_->stop_requested());
413-
__n_submitted_ += __result.__n_submitted;
414-
STDEXEC_ASSERT(__n_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
413+
__n_total_submitted_ += __result.__n_submitted;
414+
__n_newly_submitted_ += __result.__n_submitted;
415+
STDEXEC_ASSERT(__n_total_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
415416
__pending_ = (__task_queue&&) __result.__pending;
416417
while (!__result.__ready.empty()) {
417-
__n_submitted_ -= __completion_queue_.complete((__task_queue&&) __result.__ready);
418-
STDEXEC_ASSERT(0 <= __n_submitted_);
418+
__n_total_submitted_ -= __completion_queue_.complete((__task_queue&&) __result.__ready);
419+
STDEXEC_ASSERT(0 <= __n_total_submitted_);
419420
__pending_.append(__requests_.pop_all());
420-
__max_submissions = __params_.cq_entries - static_cast<__u32>(__n_submitted_);
421+
__max_submissions = __params_.cq_entries - static_cast<__u32>(__n_total_submitted_);
421422
__result = __submission_queue_.submit(
422423
(__task_queue&&) __pending_, __max_submissions, __stop_source_->stop_requested());
423-
__n_submitted_ += __result.__n_submitted;
424-
STDEXEC_ASSERT(__n_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
424+
__n_total_submitted_ += __result.__n_submitted;
425+
__n_newly_submitted_ += __result.__n_submitted;
426+
STDEXEC_ASSERT(__n_total_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
425427
__pending_ = (__task_queue&&) __result.__pending;
426428
}
427429
}
@@ -446,28 +448,30 @@ namespace exec {
446448
__is_running_.store(false, std::memory_order_relaxed);
447449
}};
448450
__pending_.append(__requests_.pop_all());
449-
while (__n_submitted_ > 0 || !__pending_.empty()) {
451+
while (__n_total_submitted_ > 0 || !__pending_.empty()) {
450452
run_some();
451453
if (
452-
__n_submitted_ == 0
453-
|| (__n_submitted_ == 1 && __break_loop_.load(std::memory_order_acquire))) {
454+
__n_total_submitted_ == 0
455+
|| (__n_total_submitted_ == 1 && __break_loop_.load(std::memory_order_acquire))) {
454456
__break_loop_.store(false, std::memory_order_relaxed);
455457
break;
456458
}
457459
constexpr int __min_complete = 1;
458460
STDEXEC_ASSERT(
459-
0 <= __n_submitted_
460-
&& __n_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
461+
0 <= __n_total_submitted_
462+
&& __n_total_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
461463
int rc = __io_uring_enter(
462-
__ring_fd_, __n_submitted_, __min_complete, IORING_ENTER_GETEVENTS);
464+
__ring_fd_, __n_newly_submitted_, __min_complete, IORING_ENTER_GETEVENTS);
463465
__throw_error_code_if(rc < 0, -rc);
464-
__n_submitted_ -= __completion_queue_.complete();
465-
STDEXEC_ASSERT(0 <= __n_submitted_);
466+
STDEXEC_ASSERT(rc <= __n_newly_submitted_);
467+
__n_newly_submitted_ -= rc;
468+
__n_total_submitted_ -= __completion_queue_.complete();
469+
STDEXEC_ASSERT(0 <= __n_total_submitted_);
466470
__pending_.append(__requests_.pop_all());
467471
}
468-
STDEXEC_ASSERT(__n_submitted_ <= 1);
472+
STDEXEC_ASSERT(__n_total_submitted_ <= 1);
469473
if (__stop_source_->stop_requested() && __pending_.empty()) {
470-
STDEXEC_ASSERT(__n_submitted_ == 0);
474+
STDEXEC_ASSERT(__n_total_submitted_ == 0);
471475
// try to shutdown the request queue
472476
int __n_in_flight_expected = 0;
473477
while (!__n_submissions_in_flight_.compare_exchange_weak(
@@ -581,7 +585,8 @@ namespace exec {
581585
std::atomic<bool> __is_running_{false};
582586
std::atomic<int> __n_submissions_in_flight_{0};
583587
std::atomic<bool> __break_loop_{false};
584-
std::ptrdiff_t __n_submitted_{0};
588+
std::ptrdiff_t __n_total_submitted_{0};
589+
std::ptrdiff_t __n_newly_submitted_{0};
585590
std::optional<stdexec::in_place_stop_source> __stop_source_{std::in_place};
586591
__completion_queue __completion_queue_;
587592
__submission_queue __submission_queue_;
@@ -638,11 +643,11 @@ namespace exec {
638643
static constexpr __task_vtable __vtable{&__ready_, &__submit_, &__complete_};
639644

640645
template <class... _Args>
641-
requires stdexec::constructible_from<_Base, std::in_place_t, _Args...>
646+
requires stdexec::constructible_from<_Base, std::in_place_t, __task*, _Args...>
642647
__io_task_facade(std::in_place_t, _Args&&... __args) noexcept(
643-
stdexec::__nothrow_constructible_from<_Base, _Args...>)
648+
stdexec::__nothrow_constructible_from<_Base, __task*, _Args...>)
644649
: __task{__vtable}
645-
, __base_(std::in_place, (_Args&&) __args...) {
650+
, __base_(std::in_place, static_cast<__task*>(this), (_Args&&) __args...) {
646651
}
647652

648653
template <class... _Args>
@@ -731,8 +736,8 @@ namespace exec {
731736
__op_->submit_stop(__sqe);
732737
} else {
733738
__sqe = ::io_uring_sqe{
734-
.opcode = IORING_OP_ASYNC_CANCEL, //
735-
.addr = bit_cast<__u64>(__op_) //
739+
.opcode = IORING_OP_ASYNC_CANCEL, //
740+
.addr = bit_cast<__u64>(__op_->__parent_) //
736741
};
737742
}
738743
#else
@@ -768,23 +773,27 @@ namespace exec {
768773

769774
template <class _Base, bool _False>
770775
struct __impl_base {
776+
__task* __parent_;
771777
_Base __base_;
772778

773779
template <class... _Args>
774-
__impl_base(std::in_place_t, _Args&&... __args) noexcept(
780+
__impl_base(__task* __parent, std::in_place_t, _Args&&... __args) noexcept(
775781
stdexec::__nothrow_constructible_from<_Base, _Args...>)
776-
: __base_((_Args&&) __args...) {
782+
: __parent_{__parent}
783+
, __base_((_Args&&) __args...) {
777784
}
778785
};
779786

780787
template <class _Base>
781788
struct __impl_base<_Base, true> {
789+
__task* __parent_;
782790
_Base __base_;
783791

784792
template <class... _Args>
785-
__impl_base(std::in_place_t, _Args&&... __args) noexcept(
793+
__impl_base(__task* __parent, std::in_place_t, _Args&&... __args) noexcept(
786794
stdexec::__nothrow_constructible_from<_Base, _Args...>)
787-
: __base_((_Args&&) __args...) {
795+
: __parent_{__parent}
796+
, __base_((_Args&&) __args...) {
788797
}
789798

790799
void submit_stop(::io_uring_sqe& __sqe) noexcept {
@@ -823,9 +832,9 @@ namespace exec {
823832

824833
template <class... _Args>
825834
requires stdexec::constructible_from<_Base, _Args...>
826-
__impl(std::in_place_t, _Args&&... __args) noexcept(
835+
__impl(std::in_place_t, __task* __parent, _Args&&... __args) noexcept(
827836
stdexec::__nothrow_constructible_from<_Base, _Args...>)
828-
: __base_t(std::in_place, (_Args&&) __args...)
837+
: __base_t(__parent, std::in_place, (_Args&&) __args...)
829838
, __stop_operation_{this} {
830839
}
831840

test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ set(stdexec_test_sources
7878
exec/test_when_any.cpp
7979
exec/test_at_coroutine_exit.cpp
8080
exec/test_materialize.cpp
81-
exec/test_io_uring_context.cpp
81+
$<$<BOOL:${STDEXEC_ENABLE_IO_URING_TESTS}>:exec/test_io_uring_context.cpp>
8282
exec/test_trampoline_scheduler.cpp
8383
exec/test_sequence_senders.cpp
8484
exec/sequence/test_empty_sequence.cpp

test/exec/test_io_uring_context.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,14 +301,23 @@ TEST_CASE("io_uring_context schedule_after -1s", "[types][io_uring][schedulers]"
301301
scope_guard guard{[&]() noexcept {
302302
context.request_stop();
303303
}};
304-
bool is_called = false;
304+
bool is_called_1 = false;
305+
bool is_called_2 = false;
306+
auto start = std::chrono::steady_clock::now();
307+
auto timeout = 100ms;
305308
sync_wait(when_any(
306309
schedule_after(scheduler, -1s) | then([&] {
307310
CHECK(io_thread.get_id() == std::this_thread::get_id());
308-
is_called = true;
311+
is_called_1 = true;
309312
}),
310-
schedule_after(scheduler, 5ms)));
311-
CHECK(is_called);
313+
schedule_after(scheduler, timeout) | then([&] {
314+
is_called_2 = true;
315+
})));
316+
auto end = std::chrono::steady_clock::now();
317+
std::chrono::nanoseconds diff = end - start;
318+
CHECK(diff.count() < std::chrono::duration_cast<std::chrono::nanoseconds>(timeout).count());
319+
CHECK(is_called_1 == true);
320+
CHECK(is_called_2 == false);
312321
}
313322
}
314323

0 commit comments

Comments
 (0)