Skip to content

Commit 493d41b

Browse files
🐛 Fix syncify with raise_sync_error=False on AnyIO 4.x.x, do not start new event loops unnecessarily (#130)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a51dbd6 commit 493d41b

File tree

4 files changed

+110
-4
lines changed

4 files changed

+110
-4
lines changed

.github/workflows/test.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ jobs:
2626
- "3.10"
2727
- "3.11"
2828
- "3.12"
29+
anyio-version:
30+
- anyio-v3
31+
- anyio-v4
2932
fail-fast: false
3033

3134
steps:
@@ -56,6 +59,12 @@ jobs:
5659
- name: Install Dependencies
5760
if: steps.cache.outputs.cache-hit != 'true'
5861
run: python -m poetry install
62+
- name: Install AnyIO v3
63+
if: matrix.anyio-version == 'anyio-v3'
64+
run: pip install --upgrade "anyio>=3.4.0,<4.0"
65+
- name: Install AnyIO v4
66+
if: matrix.anyio-version == 'anyio-v4'
67+
run: pip install --upgrade "anyio>=4.0.0,<5.0"
5968
- name: Lint
6069
run: python -m poetry run bash scripts/lint.sh
6170
- run: mkdir coverage

asyncer/_main.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626

2727
# This was obtained with: from anyio._core._eventloop import get_asynclib
2828
# Removed in https://github.com/agronholm/anyio/pull/429
29-
# First release (not released yet): 4.0-dev
29+
# Released in AnyIO 4.x.x
30+
# The new function is anyio._core._eventloop.get_async_backend but that returns a
31+
# class, not a module to extract the TaskGroup class from.
3032
def get_asynclib(asynclib_name: Union[str, None] = None) -> Any:
3133
if asynclib_name is None:
3234
asynclib_name = sniffio.current_async_library()
@@ -298,7 +300,12 @@ async def do_work(arg1, arg2, kwarg1="", kwarg2=""):
298300

299301
@functools.wraps(async_function)
300302
def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
301-
current_async_module = getattr(threadlocals, "current_async_module", None)
303+
current_async_module = (
304+
getattr(threadlocals, "current_async_backend", None)
305+
or
306+
# TODO: remove when deprecating AnyIO 3.x
307+
getattr(threadlocals, "current_async_module", None)
308+
)
302309
partial_f = functools.partial(async_function, *args, **kwargs)
303310
if current_async_module is None and raise_sync_error is False:
304311
return anyio.run(partial_f)

tests/test_syncify_no_raise.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import threading
2+
from dataclasses import dataclass
3+
from typing import List
4+
5+
import anyio
6+
from asyncer import asyncify, syncify
7+
8+
9+
@dataclass
10+
class Report:
11+
thread_id: int
12+
caller_func: str
13+
14+
15+
def test_syncify_no_raise_async():
16+
reports: List[Report] = []
17+
18+
async def do_sub_async_work():
19+
report = Report(
20+
thread_id=threading.get_ident(),
21+
caller_func="do_sub_async_work",
22+
)
23+
reports.append(report)
24+
25+
def do_sub_sync_work():
26+
report = Report(
27+
thread_id=threading.get_ident(),
28+
caller_func="do_sub_sync_work",
29+
)
30+
reports.append(report)
31+
syncify(do_sub_async_work, raise_sync_error=False)()
32+
33+
async def do_async_work():
34+
report = Report(
35+
thread_id=threading.get_ident(),
36+
caller_func="do_async_work",
37+
)
38+
reports.append(report)
39+
await asyncify(do_sub_sync_work)()
40+
41+
def do_sync_work():
42+
own_report = Report(
43+
thread_id=threading.get_ident(),
44+
caller_func="do_sync_work",
45+
)
46+
reports.append(own_report)
47+
syncify(do_async_work, raise_sync_error=False)()
48+
49+
async def main():
50+
own_report = Report(
51+
thread_id=threading.get_ident(),
52+
caller_func="main",
53+
)
54+
reports.append(own_report)
55+
await asyncify(do_sync_work)()
56+
57+
def sync_main():
58+
own_report = Report(
59+
thread_id=threading.get_ident(),
60+
caller_func="sync_main",
61+
)
62+
reports.append(own_report)
63+
do_sync_work()
64+
65+
anyio.run(main)
66+
sync_main()
67+
main_thread_id = threading.get_ident()
68+
assert reports[0].caller_func == "main"
69+
assert reports[0].thread_id == main_thread_id
70+
assert reports[1].caller_func == "do_sync_work"
71+
assert reports[1].thread_id != main_thread_id
72+
assert reports[2].caller_func == "do_async_work"
73+
assert reports[2].thread_id == main_thread_id
74+
assert reports[3].caller_func == "do_sub_sync_work"
75+
assert reports[3].thread_id != main_thread_id
76+
assert reports[4].caller_func == "do_sub_async_work"
77+
assert reports[4].thread_id == main_thread_id
78+
assert reports[5].caller_func == "sync_main"
79+
assert reports[5].thread_id == main_thread_id
80+
assert reports[6].caller_func == "do_sync_work"
81+
assert reports[6].thread_id == main_thread_id
82+
assert reports[7].caller_func == "do_async_work"
83+
assert reports[7].thread_id == main_thread_id
84+
assert reports[8].caller_func == "do_sub_sync_work"
85+
assert reports[8].thread_id != main_thread_id
86+
assert reports[9].caller_func == "do_sub_async_work"
87+
assert reports[9].thread_id == main_thread_id

tests/test_tutorial/test_soonify_return/test_tutorial002.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@ def test_tutorial():
1616
new_print = get_testing_print_function(calls)
1717

1818
with patch("builtins.print", new=new_print):
19-
with pytest.raises(ExceptionGroup) as e:
19+
with pytest.raises((ExceptionGroup, asyncer.PendingValueException)) as e:
2020
from docs_src.tutorial.soonify_return import tutorial002 as mod
2121

2222
# Avoid autoflake removing this import
2323
assert mod # pragma: nocover
24-
assert isinstance(e.value.exceptions[0], asyncer.PendingValueException)
24+
if isinstance(e.value, ExceptionGroup):
25+
assert isinstance(e.value.exceptions[0], asyncer.PendingValueException)
26+
else:
27+
assert isinstance(e.value, asyncer.PendingValueException)
2528
assert calls == []

0 commit comments

Comments
 (0)