Skip to content

Commit f865772

Browse files
committed
Nested / parallel loops
Monkey-patch asyncio's get_event_loop{,_policy} calls to return the currently-running loop / Trio event policy whenever we're within a Trio task.
1 parent 4e7bc39 commit f865772

5 files changed

Lines changed: 122 additions & 17 deletions

File tree

tests/conftest.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,15 @@ def pytest_pyfunc_call(pyfuncitem):
4848
pyfuncitem.obj = pytest.mark.trio(pyfuncitem.obj)
4949

5050

51-
asyncio.set_event_loop_policy(trio_asyncio.TrioPolicy())
51+
_old_policy = asyncio.get_event_loop_policy()
52+
_new_policy = trio_asyncio.TrioPolicy()
53+
asyncio.set_event_loop_policy(_new_policy)
54+
55+
@pytest.fixture
56+
def old_policy():
57+
asyncio.set_event_loop_policy(_old_policy)
58+
try:
59+
yield _old_policy
60+
finally:
61+
asyncio.set_event_loop_policy(_new_policy)
62+

tests/test_concurrent.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import trio
2+
import trio_asyncio
3+
import asyncio
4+
import pytest
5+
6+
# Tests for concurrent or nested loops
7+
8+
@pytest.mark.trio
9+
async def test_parallel():
10+
loops = [None, None]
11+
async with trio.open_nursery() as n:
12+
async def gen_loop(i, task_status=trio.TASK_STATUS_IGNORED):
13+
task_status.started()
14+
async with trio_asyncio.open_loop() as loop:
15+
loops[i] = loop
16+
17+
assert not isinstance(asyncio._get_running_loop(), trio_asyncio.TrioEventLoop)
18+
await n.start(gen_loop, 0)
19+
await n.start(gen_loop, 1)
20+
21+
assert isinstance(loops[0], trio_asyncio.TrioEventLoop)
22+
assert isinstance(loops[1], trio_asyncio.TrioEventLoop)
23+
assert loops[0] is not loops[1]
24+
25+
@pytest.mark.trio
26+
async def test_nested():
27+
loops = [None, None]
28+
async with trio.open_nursery() as n:
29+
async def gen_loop(i, task_status=trio.TASK_STATUS_IGNORED):
30+
task_status.started()
31+
async with trio_asyncio.open_loop() as loop:
32+
loops[i] = loop
33+
if i > 0:
34+
await n.start(gen_loop, i-1)
35+
36+
assert not isinstance(asyncio._get_running_loop(), trio_asyncio.TrioEventLoop)
37+
await n.start(gen_loop, 1)
38+
assert not isinstance(asyncio._get_running_loop(), trio_asyncio.TrioEventLoop)
39+
assert isinstance(loops[0], trio_asyncio.TrioEventLoop)
40+
assert isinstance(loops[1], trio_asyncio.TrioEventLoop)
41+
assert loops[0] is not loops[1]
42+
43+
async def _test_same_task():
44+
loops = [None, None]
45+
assert isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
46+
def get_loop(i):
47+
loops[i] = (asyncio.get_event_loop(), asyncio.get_event_loop_policy())
48+
async with trio.open_nursery() as n:
49+
async with trio_asyncio.open_loop() as loop1:
50+
async with trio_asyncio.open_loop() as loop2:
51+
loop1.call_later(0.1, get_loop, 0)
52+
loop2.call_later(0.1, get_loop, 1)
53+
await trio.sleep(0.2)
54+
55+
assert isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
56+
assert not isinstance(asyncio._get_running_loop(), trio_asyncio.TrioEventLoop)
57+
assert isinstance(loops[0][0], trio_asyncio.TrioEventLoop)
58+
assert isinstance(loops[1][0], trio_asyncio.TrioEventLoop)
59+
assert isinstance(loops[1][1], trio_asyncio.TrioPolicy)
60+
assert loops[0][0] is not loops[1][0]
61+
assert loops[0][1] is loops[1][1]
62+
assert loops[0][1] is asyncio.get_event_loop_policy()
63+
64+
def test_same_task(old_policy):
65+
assert not isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
66+
trio.run(_test_same_task)

trio_asyncio/async_.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,12 @@ def _main_loop_exit(self):
105105
super()._main_loop_exit()
106106
self._thread = None
107107

108+
from .loop import _current_loop
109+
outer_loop = _current_loop.loop
110+
108111
async with trio.open_nursery() as nursery:
109112
loop = TrioEventLoop(queue_len=queue_len)
113+
_current_loop.loop = loop
110114
try:
111115
loop._closed = False
112116
await loop._main_loop_init(nursery)
@@ -119,3 +123,5 @@ def _main_loop_exit(self):
119123
await loop._main_loop_exit()
120124
loop.close()
121125
nursery.cancel_scope.cancel()
126+
_current_loop.loop = outer_loop
127+

trio_asyncio/base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,6 @@ async def _main_loop_init(self, nursery):
646646
self._nursery = nursery
647647
self._task = trio.hazmat.current_task()
648648
self._token = trio.hazmat.current_trio_token()
649-
asyncio.events._set_running_loop(self)
650649

651650
async def _main_loop(self, task_status=trio.TASK_STATUS_IGNORED):
652651
"""Run the loop by processing its event queue.
@@ -726,7 +725,6 @@ async def _main_loop_exit(self):
726725
# clean core fields
727726
self._nursery = None
728727
self._task = None
729-
asyncio.events._set_running_loop(None)
730728

731729
def is_running(self):
732730
if self._stopped is None:

trio_asyncio/loop.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@
3030
'TrioPolicy',
3131
]
3232

33+
_current_loop = trio.TaskLocal(loop=None, policy=None)
3334

3435
class _TrioPolicy(asyncio.events.BaseDefaultEventLoopPolicy):
3536
_loop_factory = TrioEventLoop
3637

3738
def __init__(self):
3839
super().__init__()
39-
self._trio_local = trio.TaskLocal(_loop=None, _task=False)
40+
current_loop = trio.TaskLocal(_loop=None, _task=False)
4041

4142
def new_event_loop(self):
4243
try:
@@ -75,45 +76,68 @@ def get_event_loop(self):
7576
# this creates a new loop in the main task
7677
return super().get_event_loop()
7778
else:
78-
return self._trio_local._loop
79+
return _current_loop.loop
7980

8081
@property
8182
def current_event_loop(self):
8283
"""The currently-running event loop, if one exists."""
8384
try:
84-
return self._trio_local._loop
85+
return _current_loop.loop
8586
except RuntimeError:
8687
# in the main thread this would create a new loop
8788
# return super().get_event_loop()
88-
return self._local._loop
89+
return super().get_event_loop()
8990

9091
def set_event_loop(self, loop):
9192
"""Set the current event loop."""
9293
try:
93-
task = trio.hazmat.current_task()
94+
_current_loop.loop = loop
9495
except RuntimeError:
9596
return super().set_event_loop(loop)
9697

97-
# This test will not trigger if you create a new asyncio event loop
98-
# in a sub-task, which is exactly what we intend to be possible
99-
if self._trio_local._loop is not None and loop is not None and \
100-
self._trio_local._task == task:
101-
raise RuntimeError('You cannot replace an event loop.', self._trio_local._loop, loop)
102-
self._trio_local._loop = loop
103-
self._trio_local._task = task
98+
99+
100+
# We need to monkey-patch asyncio's policy+loop getters to return our
101+
# TrioPolicy+loop whenever we are within Trio.
102+
103+
from asyncio import events as _aio_event
104+
105+
_orig_policy_get = _aio_event.get_event_loop_policy
106+
def _new_policy_get():
107+
try:
108+
policy = _current_loop.policy
109+
except RuntimeError:
110+
return _orig_policy_get()
111+
112+
if policy is None:
113+
policy = TrioPolicy()
114+
_current_loop.policy = policy
115+
return policy
116+
_aio_event.get_event_loop_policy = _new_policy_get
117+
asyncio.get_event_loop_policy = _new_policy_get
118+
119+
_orig_loop_get = _aio_event._get_running_loop
120+
def _new_loop_get():
121+
try:
122+
return _current_loop.loop
123+
except RuntimeError:
124+
return _orig_loop_get()
125+
_aio_event._get_running_loop = _new_loop_get
104126

105127

106128
class TrioPolicy(_TrioPolicy, asyncio.DefaultEventLoopPolicy):
129+
"""This is the loop policy that's active whenever we're in a Trio context."""
130+
107131
def _init_watcher(self):
108132
with asyncio.events._lock:
109133
if self._watcher is None: # pragma: no branch
110134
self._watcher = TrioChildWatcher()
111135
if isinstance(threading.current_thread(), threading._MainThread):
112-
self._watcher.attach_loop(self._trio_local._loop)
136+
self._watcher.attach_loop(_current_loop.loop)
113137

114138
if self._watcher is not None and \
115139
isinstance(threading.current_thread(), threading._MainThread):
116-
self._watcher.attach_loop(self._trio_local._loop)
140+
self._watcher.attach_loop(_current_loop.loop)
117141

118142
def set_child_watcher(self, watcher):
119143
if watcher is not None:

0 commit comments

Comments
 (0)