21 changes: 17 additions & 4 deletions docs/source/reference-lowlevel.rst
Original file line number Diff line number Diff line change
@@ -344,6 +344,8 @@ Spawning threads
.. autofunction:: start_thread_soon


.. _ki-handling:

Safer KeyboardInterrupt handling
================================

@@ -355,10 +357,21 @@ correctness invariants. On the other, if the user accidentally writes
an infinite loop, we do want to be able to break out of that. Our
solution is to install a default signal handler which checks whether
it's safe to raise :exc:`KeyboardInterrupt` at the place where the
signal is received. If so, then we do; otherwise, we schedule a
:exc:`KeyboardInterrupt` to be delivered to the main task at the next
available opportunity (similar to how :exc:`~trio.Cancelled` is
delivered).
signal is received. If so, then we do. Otherwise, we cancel all tasks
and raise `KeyboardInterrupt` out of :func:`trio.run` once they exit.

.. note:: This behavior means it's not a good idea to try to catch
`KeyboardInterrupt` within a Trio task. Most Trio
programs are I/O-bound, so most interrupts will be received while
no task is running (because Trio is waiting for I/O). There's no
task that should obviously receive the interrupt in such cases, so
Trio doesn't raise it within a task at all: every task gets cancelled,
then `KeyboardInterrupt` is raised once that's complete.

If you want to handle Ctrl+C by doing something other than "cancel
all tasks", then you should use :func:`~trio.open_signal_receiver` to
install a handler for `signal.SIGINT`. If you do that, then Ctrl+C will
go to your handler, and it can do whatever it wants.

So that's great, but – how do we know whether we're in one of the
sensitive parts of the program or not?
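
As a rough, standard-library-only sketch of the decision described above (not Trio's actual implementation): the handler raises immediately when the interrupted code is unprotected, and otherwise records the interrupt, requests cancellation, and lets the run raise `KeyboardInterrupt` on its way out. `ToyRunner`, `ki_protected`, `cancel_requested`, and `protected_section` are hypothetical names, and the boolean flag is only a stand-in for however the real check is made.

```python
import signal


class ToyRunner:
    """Hypothetical stand-in for a run loop; not Trio's Runner."""

    def __init__(self):
        self.ki_protected = False      # assumption: a plain flag marks sensitive code
        self.ki_pending = False
        self.cancel_requested = False

    def sigint_handler(self, signum, frame):
        # Same signature as a handler installed with signal.signal().
        if not self.ki_protected:
            # Safe to interrupt right where the signal arrived.
            raise KeyboardInterrupt
        # Not safe: remember the interrupt and ask every task to stop.
        self.ki_pending = True
        self.cancel_requested = True

    def run(self, fn):
        result = fn(self)
        if self.ki_pending:
            # Deferred interrupt: surfaces only after the work has wound down.
            raise KeyboardInterrupt
        return result


def protected_section(runner):
    runner.ki_protected = True
    try:
        # Simulate Ctrl+C arriving in the middle of sensitive code.
        runner.sigint_handler(signal.SIGINT, None)
    finally:
        runner.ki_protected = False
    return "finished cleanly"
```

In this toy model an `except KeyboardInterrupt` inside `protected_section` would never fire, because the deferred interrupt is only raised by `run` itself; that mirrors the advice to catch it around the call to `trio.run`.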
42 changes: 42 additions & 0 deletions newsfragments/733.breaking.rst
@@ -0,0 +1,42 @@
:ref:`Sometimes <ki-handling>`, a Trio program receives an interrupt
signal (Ctrl+C) at a time when Python's default response (raising
`KeyboardInterrupt` immediately) might corrupt Trio's internal
state. Previously, Trio would handle this situation by raising the
`KeyboardInterrupt` at the next :ref:`checkpoint <checkpoints>` executed
by the main task (the one running the function you passed to :func:`trio.run`).
This was responsible for a lot of internal complexity and sometimes led to
surprising behavior.

With this release, such a "deferred" `KeyboardInterrupt` is handled in a
different way: Trio will first cancel all running tasks, then raise
`KeyboardInterrupt` directly out of the call to :func:`trio.run`.
The difference is relevant if you have code that tries to catch
`KeyboardInterrupt` within Trio. This was never entirely robust, but it
previously might have worked in many cases, whereas now it will never
catch the interrupt.

An example of code that mostly worked on previous releases, but won't
work on this release::

    async def main():
        try:
            await trio.sleep_forever()
        except KeyboardInterrupt:
            print("interrupted")
    trio.run(main)

The fix is to catch `KeyboardInterrupt` outside Trio::

    async def main():
        await trio.sleep_forever()
    try:
        trio.run(main)
    except KeyboardInterrupt:
        print("interrupted")

If that doesn't work for you (because you want to respond to
`KeyboardInterrupt` by doing something other than cancelling all
tasks), then you can start a task that uses
`trio.open_signal_receiver` to receive the interrupt signal ``SIGINT``
directly and handle it however you wish. Such a task takes precedence
over Trio's default interrupt handling.
36 changes: 7 additions & 29 deletions src/trio/_core/_run.py
@@ -1721,6 +1721,7 @@ class Runner: # type: ignore[explicit-any]
system_context: contextvars.Context = attrs.field(kw_only=True)
main_task: Task | None = None
main_task_outcome: Outcome[object] | None = None
main_task_nursery: Nursery | None = None

entry_queue: EntryQueue = attrs.Factory(EntryQueue)
trio_token: TrioToken | None = None
@@ -2055,12 +2056,12 @@ async def init(
# All other system tasks run here:
async with open_nursery() as self.system_nursery:
# Only the main task runs here:
async with open_nursery() as main_task_nursery:
async with open_nursery() as self.main_task_nursery:
try:
self.main_task = self.spawn_impl(
async_fn,
args,
main_task_nursery,
self.main_task_nursery,
None,
)
except BaseException as exc:
@@ -2105,30 +2106,13 @@ def current_trio_token(self) -> TrioToken:

ki_pending: bool = False

# deliver_ki is broke. Maybe move all the actual logic and state into
# RunToken, and we'll only have one instance per runner? But then we can't
# have a public constructor. Eh, but current_run_token() returning a
# unique object per run feels pretty nice. Maybe let's just go for it. And
# keep the class public so people can isinstance() it if they want.

# This gets called from signal context
def deliver_ki(self) -> None:
self.ki_pending = True
with suppress(RunFinishedError):
self.entry_queue.run_sync_soon(self._deliver_ki_cb)
assert self.main_task_nursery is not None

def _deliver_ki_cb(self) -> None:
if not self.ki_pending:
return
# Can't happen because main_task and run_sync_soon_task are created at
# the same time -- so even if KI arrives before main_task is created,
# we won't get here until afterwards.
assert self.main_task is not None
if self.main_task_outcome is not None:
# We're already in the process of exiting -- leave ki_pending set
# and we'll check it again on our way out of run().
return
self.main_task._attempt_delivery_of_pending_ki()
with suppress(RunFinishedError):
self.entry_queue.run_sync_soon(self.main_task_nursery.cancel_scope.cancel)

################
# Quiescing
@@ -2787,10 +2771,6 @@ def unrolled_run(
elif type(msg) is WaitTaskRescheduled:
task._cancel_points += 1
task._abort_func = msg.abort_func
# KI is "outside" all cancel scopes, so check for it
# before checking for regular cancellation:
if runner.ki_pending and task is runner.main_task:
task._attempt_delivery_of_pending_ki()
task._attempt_delivery_of_any_pending_cancel()
elif type(msg) is PermanentlyDetachCoroutineObject:
# Pretend the task just exited with the given outcome
@@ -2923,9 +2903,7 @@ async def checkpoint() -> None:
await cancel_shielded_checkpoint()
task = current_task()
task._cancel_points += 1
if task._cancel_status.effectively_cancelled or (
task is task._runner.main_task and task._runner.ki_pending
):
if task._cancel_status.effectively_cancelled:
with CancelScope(deadline=-inf):
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)

3 changes: 2 additions & 1 deletion src/trio/_core/_tests/test_guest_mode.py
@@ -639,7 +639,8 @@ async def trio_main(in_host: InHost) -> None:

with pytest.raises(KeyboardInterrupt) as excinfo:
trivial_guest_run(trio_main)
assert excinfo.value.__context__ is None
assert isinstance(excinfo.value.__context__, trio.Cancelled)
assert excinfo.value.__context__.__context__ is None
# Signal handler should be restored properly on exit
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler

55 changes: 22 additions & 33 deletions src/trio/_core/_tests/test_ki.py
@@ -313,8 +313,8 @@ async def check_unprotected_kill() -> None:
_core.run(check_unprotected_kill)
assert record_set == {"s1 ok", "s2 ok", "r1 raise ok"}

# simulated control-C during raiser, which is *protected*, so the KI gets
# delivered to the main task instead
# simulated control-C during raiser, which is *protected*, so the run
# gets cancelled instead.
print("check 2")
record_set = set()

@@ -325,9 +325,12 @@ async def check_protected_kill() -> None:
nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record_set)
# __aexit__ blocks, and then receives the KI

# raises inside a nursery, so the KeyboardInterrupt is wrapped in an ExceptionGroup
with RaisesGroup(KeyboardInterrupt):
# KeyboardInterrupt is inserted from the trio.run
with pytest.raises(KeyboardInterrupt) as excinfo:
_core.run(check_protected_kill)

# TODO: be consistent about providing Cancelled tree as __context__
Member

what's required for this?

Contributor Author

Well the issue here is that in the above test, no __context__ tree is kept which is wrong. Maybe.

But now that I think about it more, I think this might be fine? If a task is busy-looping, it'll raise a KeyboardInterrupt too. I don't really like the Cancelled tree, so maybe we should do the opposite of this and try to eliminate it! I'll update this to have that written down. (I'd rather know whether this tree of cancellation is actually a burden, but I suspect it'll add many lines to tracebacks in the event of Ctrl+C for not much use.)
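
For readers following the `__context__` assertions in this thread, here is a small plain-Python sketch of the mechanism involved, under the assumption that the "Cancelled tree" refers to implicit exception chaining. It does not use Trio: a `RuntimeError` stands in for `trio.Cancelled`, and all function names are hypothetical.

```python
def interrupt_while_handling(inner):
    """Raise KeyboardInterrupt while `inner` is being handled."""
    try:
        raise inner
    except BaseException:
        # Raising inside an except block chains the active exception
        # implicitly: it becomes the new exception's __context__.
        raise KeyboardInterrupt


def interrupt_directly():
    # No exception is being handled here, so __context__ stays None.
    raise KeyboardInterrupt


def context_of(fn, *args):
    """Return the __context__ of the KeyboardInterrupt raised by fn."""
    try:
        fn(*args)
    except KeyboardInterrupt as exc:
        return exc.__context__
    return None
```

Whether the interrupt carries a context therefore depends on whether it is raised while another exception is still being handled, which is one plausible reason the two tests in this PR assert different things.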

assert excinfo.value.__context__ is None
assert record_set == {"s1 ok", "s2 ok", "r1 cancel ok"}

# kill at last moment still raises (run_sync_soon until it raises an
@@ -373,10 +376,11 @@ async def main_1() -> None:
async def main_2() -> None:
assert _core.currently_ki_protected()
ki_self()
with pytest.raises(KeyboardInterrupt):
with pytest.raises(_core.Cancelled):
await _core.checkpoint_if_cancelled()

_core.run(main_2)
with pytest.raises(KeyboardInterrupt):
_core.run(main_2)

# KI arrives while main task is not abortable, b/c already scheduled
print("check 6")
@@ -388,10 +392,11 @@ async def main_3() -> None:
await _core.cancel_shielded_checkpoint()
await _core.cancel_shielded_checkpoint()
await _core.cancel_shielded_checkpoint()
with pytest.raises(KeyboardInterrupt):
with pytest.raises(_core.Cancelled):
await _core.checkpoint()

_core.run(main_3)
with pytest.raises(KeyboardInterrupt):
_core.run(main_3)

# KI arrives while main task is not abortable, b/c refuses to be aborted
print("check 7")
@@ -407,10 +412,11 @@ def abort(_: RaiseCancelT) -> Abort:
return _core.Abort.FAILED

assert await _core.wait_task_rescheduled(abort) == 1
with pytest.raises(KeyboardInterrupt):
with pytest.raises(_core.Cancelled):
await _core.checkpoint()

_core.run(main_4)
with pytest.raises(KeyboardInterrupt):
_core.run(main_4)

# KI delivered via slow abort
print("check 8")
@@ -426,11 +432,12 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
_core.reschedule(task, result)
return _core.Abort.FAILED

with pytest.raises(KeyboardInterrupt):
with pytest.raises(_core.Cancelled):
assert await _core.wait_task_rescheduled(abort)
await _core.checkpoint()

_core.run(main_5)
with pytest.raises(KeyboardInterrupt):
_core.run(main_5)

# KI arrives just before main task exits, so the run_sync_soon machinery
# is still functioning and will accept the callback to deliver the KI, but
@@ -457,10 +464,11 @@ async def main_7() -> None:
# ...but even after the KI, we keep running uninterrupted...
record_list.append("ok")
# ...until we hit a checkpoint:
with pytest.raises(KeyboardInterrupt):
with pytest.raises(_core.Cancelled):
await sleep(10)

_core.run(main_7, restrict_keyboard_interrupt_to_checkpoints=True)
with pytest.raises(KeyboardInterrupt):
_core.run(main_7, restrict_keyboard_interrupt_to_checkpoints=True)
assert record_list == ["ok"]
record_list = []
# Exact same code raises KI early if we leave off the argument, doesn't
@@ -469,25 +477,6 @@ async def main_7() -> None:
_core.run(main_7)
assert record_list == []

# KI arrives while main task is inside a cancelled cancellation scope
# the KeyboardInterrupt should take priority
print("check 11")

@_core.enable_ki_protection
async def main_8() -> None:
assert _core.currently_ki_protected()
with _core.CancelScope() as cancel_scope:
cancel_scope.cancel()
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
ki_self()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
with pytest.raises(_core.Cancelled):
await _core.checkpoint()

_core.run(main_8)


def test_ki_is_good_neighbor() -> None:
# in the unlikely event someone overwrites our signal handler, we leave
44 changes: 33 additions & 11 deletions src/trio/_repl.py
@@ -12,14 +12,18 @@

import trio
import trio.lowlevel
from trio._core._run_context import GLOBAL_RUN_CONTEXT
from trio._util import final


@final
class TrioInteractiveConsole(InteractiveConsole):
runner: trio._core._run.Runner | None

def __init__(self, repl_locals: dict[str, object] | None = None) -> None:
super().__init__(locals=repl_locals)
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
self.runner = None

def runcode(self, code: types.CodeType) -> None:
# https://github.com/python/typeshed/issues/13768
@@ -28,6 +32,17 @@ def runcode(self, code: types.CodeType) -> None:
result = trio.from_thread.run(outcome.acapture, func)
else:
result = trio.from_thread.run_sync(outcome.capture, func)

# clear ki_pending
assert self.runner is not None
ki_pending = self.runner.ki_pending
self.runner.ki_pending = False

if ki_pending:
exc: BaseException | None = KeyboardInterrupt()
else:
exc = None

if isinstance(result, outcome.Error):
# If it is SystemExit, quit the repl. Otherwise, print the traceback.
# If there is a SystemExit inside a BaseExceptionGroup, it probably isn't
@@ -37,21 +52,28 @@ def runcode(self, code: types.CodeType) -> None:
if isinstance(result.error, SystemExit):
raise result.error
else:
# Inline our own version of self.showtraceback that can use
# outcome.Error.error directly to print clean tracebacks.
# This also means overriding self.showtraceback does nothing.
sys.last_type, sys.last_value = type(result.error), result.error
sys.last_traceback = result.error.__traceback__
# see https://docs.python.org/3/library/sys.html#sys.last_exc
if sys.version_info >= (3, 12):
sys.last_exc = result.error
if exc:
exc.__context__ = result.error
else:
exc = result.error

if exc is not None:
# Inline our own version of self.showtraceback that can use
# outcome.Error.error directly to print clean tracebacks.
# This also means overriding self.showtraceback does nothing.
sys.last_type, sys.last_value = type(exc), exc
sys.last_traceback = exc.__traceback__
# see https://docs.python.org/3/library/sys.html#sys.last_exc
if sys.version_info >= (3, 12):
sys.last_exc = exc

# We always use sys.excepthook, unlike other implementations.
# This means that overriding self.write also does nothing to tbs.
sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback)
# We always use sys.excepthook, unlike other implementations.
# This means that overriding self.write also does nothing to tbs.
sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback)


async def run_repl(console: TrioInteractiveConsole) -> None:
console.runner = GLOBAL_RUN_CONTEXT.runner
banner = (
f"trio REPL {sys.version} on {sys.platform}\n"
f'Use "await" directly instead of "trio.run()".\n'
2 changes: 2 additions & 0 deletions src/trio/_threads.py
@@ -351,6 +351,8 @@ async def to_thread_run_sync(
if limiter is None:
limiter = current_default_thread_limiter()

# TODO: cancel_register can probably be a single element tuple for typing reasons

# Holds a reference to the task that's blocked in this function waiting
# for the result – or None if this function was cancelled and we should
# discard the result.