Skip to content

Commit 9eea42e

Browse files
committed
Refactor loop policy access
With Python 3.7, the C runtime caches get_running_loop_policy() so we can no longer monkey-patch it effectively. On the other hand, we really want per-thread loop policies. Thus, the Trio loop policy is (again) installed unconditionally. `get_running_loop_policy()` now does this: * when Trio is running: return itself. * otherwise: return a per-thread policy, or the original loop policy. * `set_running_loop_policy()` refuses to install a TrioPolicy, and otherwise installs into a per-thread variable. All methods of `TrioPolicy` now unconditionally access the current task's loop (as per contextvar) if they're running within Trio. Otherwise they defer to the per-thread loop policy.
1 parent 4b816e6 commit 9eea42e

7 files changed

Lines changed: 71 additions & 95 deletions

File tree

tests/aiotest/test_callback.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,31 +35,6 @@ def append_result(loop, result, value):
3535
# event loop, when func() is done.
3636
assert result == ['[]', 'yes']
3737

38-
def test_soon_stop_soon(self, sync_loop):
39-
result = []
40-
41-
def hello():
42-
result.append("Hello")
43-
44-
def world():
45-
result.append("World")
46-
sync_loop.stop()
47-
48-
sync_loop.call_soon(hello)
49-
sync_loop.stop()
50-
sync_loop.call_soon(world)
51-
52-
sync_loop.run_forever()
53-
if False: # config.stopping:
54-
assert result == ["Hello", "World"]
55-
else:
56-
# ensure that world() is not called, since stop() was scheduled
57-
# before call_soon(world)
58-
assert result == ["Hello"]
59-
60-
sync_loop.run_forever()
61-
assert result == ["Hello", "World"]
62-
6338
@pytest.mark.trio
6439
async def test_close(self, loop, config):
6540
if not config.call_soon_check_closed:

tests/conftest.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88
import asyncio
99
import trio_asyncio
10+
import trio_asyncio.loop as loop_
1011
import inspect
1112

1213
# Hacks for <3.7
@@ -58,29 +59,9 @@ async def loop():
5859
await loop.stop().wait()
5960

6061

61-
@pytest.fixture
62-
def sync_loop():
63-
loop = asyncio.new_event_loop()
64-
with loop:
65-
yield loop
66-
67-
6862
# auto-trio-ize all async functions
6963
@pytest.hookimpl(tryfirst=True)
7064
def pytest_pyfunc_call(pyfuncitem):
7165
if inspect.iscoroutinefunction(pyfuncitem.obj):
7266
pyfuncitem.obj = pytest.mark.trio(pyfuncitem.obj)
7367

74-
75-
_old_policy = asyncio.get_event_loop_policy()
76-
_new_policy = trio_asyncio.TrioPolicy()
77-
asyncio.set_event_loop_policy(_new_policy)
78-
79-
80-
@pytest.fixture
81-
def old_policy():
82-
asyncio.set_event_loop_policy(_old_policy)
83-
try:
84-
yield _old_policy
85-
finally:
86-
asyncio.set_event_loop_policy(_new_policy)

tests/python/test_events.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2393,7 +2393,8 @@ def __await__(self):
23932393
coro = Coro()
23942394
# Some coroutines might not have '__name__', such as
23952395
# built-in async_gen.asend().
2396-
self.assertEqual(coroutines._format_coroutine(coro), 'Coro()')
2396+
2397+
self.assertEqual(coroutines._format_coroutine(coro), 'Coro()' if sys.version_info < (3, 7) else '<Coro without __name__>()')
23972398

23982399

23992400
class TimerTests(unittest.TestCase):

tests/python/test_tasks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,8 +1829,9 @@ def kill_me(loop):
18291829
# See http://bugs.python.org/issue29271 for details:
18301830
asyncio.set_event_loop(self.loop)
18311831
try:
1832-
self.assertEqual(asyncio.all_tasks(), {task})
1833-
self.assertEqual(asyncio.all_tasks(None), {task})
1832+
self.assertEqual(asyncio.all_tasks(self.loop), {task})
1833+
# self.assertEqual(asyncio.all_tasks(None), {task})
1834+
# with 3.7 all_tasks uses get_running_loop, which isn't
18341835
finally:
18351836
asyncio.set_event_loop(None)
18361837

tests/test_concurrent.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ async def gen_loop(i, task_status=trio.TASK_STATUS_IGNORED):
4646

4747

4848
async def _test_same_task():
49-
assert not isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
49+
assert isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
5050

5151
def get_loop(i, loop, policy):
5252
assert loop == asyncio.get_event_loop()
@@ -57,15 +57,16 @@ def get_loop(i, loop, policy):
5757
policy = asyncio.get_event_loop_policy()
5858
assert isinstance(policy, trio_asyncio.TrioPolicy)
5959
async with trio_asyncio.open_loop() as loop2:
60-
assert policy == asyncio.get_event_loop_policy()
60+
p2 = asyncio.get_event_loop_policy()
61+
assert policy is p2, (policy,p2)
6162
loop1.call_later(0.1, get_loop, 0, loop1, policy)
6263
loop2.call_later(0.1, get_loop, 1, loop2, policy)
6364
await trio.sleep(0.2)
6465

65-
assert not isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
66-
assert not isinstance(asyncio._get_running_loop(), trio_asyncio.TrioEventLoop)
66+
assert isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
67+
assert asyncio._get_running_loop() is None
6768

68-
69-
def test_same_task(old_policy):
69+
def test_same_task():
7070
assert not isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
7171
trio.run(_test_same_task)
72+

tests/test_sync.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,6 @@
33

44

55
class TestSync:
6-
def test_basic_mainloop(self, sync_loop):
7-
async def foo():
8-
return "bar"
9-
10-
async def bar():
11-
return "baz"
12-
13-
res = sync_loop.run_until_complete(foo())
14-
assert res == "bar"
15-
res = sync_loop.run_until_complete(bar())
16-
assert res == "baz"
17-
186
def test_explicit_mainloop(self):
197
async def foo():
208
return "bar"
@@ -29,25 +17,3 @@ async def bar():
2917
assert res == "baz"
3018
loop.close()
3119

32-
def test_basic_errloop(self, sync_loop):
33-
async def foo():
34-
raise RuntimeError("bar")
35-
36-
with pytest.raises(RuntimeError) as res:
37-
sync_loop.run_until_complete(foo())
38-
if res.value.args[0] != "bar":
39-
raise res.value
40-
41-
def test_explicit_errloop(self):
42-
async def foo():
43-
raise RuntimeError("bar")
44-
45-
loop = asyncio.new_event_loop()
46-
with loop:
47-
with pytest.raises(RuntimeError) as res:
48-
try:
49-
loop.run_until_complete(foo())
50-
finally:
51-
pass
52-
if res.value.args[0] != "bar":
53-
raise res.value

trio_asyncio/loop.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@
3535
current_loop = ContextVar('trio_aio_loop', default=None)
3636
current_policy = ContextVar('trio_aio_policy', default=None)
3737

38+
_faked_policy = threading.local()
39+
40+
# We can monkey-patch asyncio's get_event_loop_policy but if asyncio is
41+
# imported before Trio, the asyncio acceleration C code in 3.7+ caches
42+
# get_event_loop_policy.
43+
# Thus we always set our policy. After that, our monkeypatched
44+
# setter stores the policy in a thread-local variable to which our policy
45+
# will forward all requests when Trio is not running.
3846

3947
class _TrioPolicy(asyncio.events.BaseDefaultEventLoopPolicy):
4048
_loop_factory = TrioEventLoop
@@ -49,6 +57,10 @@ def new_event_loop(self):
4957
DeprecationWarning,
5058
stacklevel=2
5159
)
60+
real_policy = getattr(_faked_policy, 'policy', None)
61+
if real_policy is not None:
62+
return real_policy.new_event_loop()
63+
5264
from .sync import SyncTrioEventLoop
5365
loop = SyncTrioEventLoop()
5466
return loop
@@ -74,6 +86,10 @@ def get_event_loop(self):
7486
trio.hazmat.current_task()
7587
except RuntimeError: # no Trio task is active
7688
# this creates a new loop in the main task
89+
real_policy = getattr(_faked_policy, 'policy', None)
90+
if real_policy is not None:
91+
return real_policy.get_event_loop()
92+
7793
return super().get_event_loop()
7894
else:
7995
return current_loop.get()
@@ -88,20 +104,51 @@ def current_event_loop(self):
88104

89105
def set_event_loop(self, loop):
90106
"""Set the current event loop."""
91-
current_loop.set(loop)
92-
107+
try:
108+
trio.hazmat.current_task()
109+
except RuntimeError: # no Trio task is active
110+
# this creates a new loop in the main task
111+
real_policy = getattr(_faked_policy, 'policy', None)
112+
if real_policy is not None:
113+
return real_policy.set_event_loop(loop)
114+
return super().set_event_loop(loop)
115+
else:
116+
current_loop.set(loop)
93117

94-
# We need to monkey-patch asyncio's policy+loop getters to return our
95-
# TrioPolicy and the current loop whenever we are within Trio.
96118

97119
from asyncio import events as _aio_event
98120

99121
#####
100122

101123
_orig_policy_get = _aio_event.get_event_loop_policy
102124

103-
104125
def _new_policy_get():
126+
try:
127+
task = trio.hazmat.current_task()
128+
except RuntimeError:
129+
policy = getattr(_faked_policy, "policy", None)
130+
if policy is None:
131+
policy = _original_policy
132+
else:
133+
policy = task.context.get(current_policy, None)
134+
if policy is None:
135+
policy = _new_policy
136+
return policy
137+
138+
139+
_aio_event.get_event_loop_policy = _new_policy_get
140+
asyncio.get_event_loop_policy = _new_policy_get
141+
142+
#####
143+
144+
_orig_policy_set = _aio_event.set_event_loop_policy
145+
146+
def _new_policy_set(new_policy):
147+
if isinstance(new_policy, TrioPolicy):
148+
raise RuntimeError("You can't set the Trio loop policy manually")
149+
assert isinstance(new_policy, asyncio.AbstractEventLoopPolicy)
150+
_faked_policy.policy = new_policy
151+
105152
try:
106153
task = trio.hazmat.current_task()
107154
except RuntimeError:
@@ -113,8 +160,9 @@ def _new_policy_get():
113160
return policy
114161

115162

116-
_aio_event.get_event_loop_policy = _new_policy_get
117-
asyncio.get_event_loop_policy = _new_policy_get
163+
_aio_event.set_event_loop_policy = _new_policy_set
164+
asyncio.set_event_loop_policy = _new_policy_set
165+
118166

119167
#####
120168

@@ -179,6 +227,9 @@ def set_child_watcher(self, watcher):
179227
watcher.attach_loop(loop)
180228
super().set_child_watcher(watcher)
181229

230+
_original_policy = _orig_policy_get()
231+
_new_policy = TrioPolicy()
232+
_orig_policy_set(_new_policy)
182233

183234
class TrioChildWatcher(asyncio.AbstractChildWatcher if sys.platform != 'win32' else object):
184235
# AbstractChildWatcher not available under Windows

0 commit comments

Comments
 (0)