Skip to content

Commit ff22786

Browse files
committed
Use contextvars instead of TaskLocal.
In order to make this code work correctly ``` def get_loop(loop): assert loop == asyncio.get_event_loop() async with trio.open_nursery(): async with trio_asyncio.open_loop() as loop1: async with trio_asyncio.open_loop() as loop2: loop1.call_soon(get_loop, loop1) ``` the context values for the current task+policy are taken from the currently-running task (as ``get_loop`` runs, this is the task running ``loop1``), instead of from the context of the caller, which would result in ``loop2``.
1 parent 97dd2dd commit ff22786

4 files changed

Lines changed: 51 additions & 46 deletions

File tree

tests/test_concurrent.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,27 +46,24 @@ async def gen_loop(i, task_status=trio.TASK_STATUS_IGNORED):
4646

4747

4848
async def _test_same_task():
49-
loops = [None, None]
50-
assert isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
49+
assert not isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
5150

52-
def get_loop(i):
53-
loops[i] = (asyncio.get_event_loop(), asyncio.get_event_loop_policy())
51+
def get_loop(i, loop, policy):
52+
assert loop == asyncio.get_event_loop()
53+
assert policy == asyncio.get_event_loop_policy()
5454

5555
async with trio.open_nursery():
5656
async with trio_asyncio.open_loop() as loop1:
57+
policy = asyncio.get_event_loop_policy()
58+
assert isinstance(policy, trio_asyncio.TrioPolicy)
5759
async with trio_asyncio.open_loop() as loop2:
58-
loop1.call_later(0.1, get_loop, 0)
59-
loop2.call_later(0.1, get_loop, 1)
60+
assert policy == asyncio.get_event_loop_policy()
61+
loop1.call_later(0.1, get_loop, 0, loop1, policy)
62+
loop2.call_later(0.1, get_loop, 1, loop2, policy)
6063
await trio.sleep(0.2)
6164

62-
assert isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
65+
assert not isinstance(asyncio.get_event_loop_policy(), trio_asyncio.TrioPolicy)
6366
assert not isinstance(asyncio._get_running_loop(), trio_asyncio.TrioEventLoop)
64-
assert isinstance(loops[0][0], trio_asyncio.TrioEventLoop)
65-
assert isinstance(loops[1][0], trio_asyncio.TrioEventLoop)
66-
assert isinstance(loops[1][1], trio_asyncio.TrioPolicy)
67-
assert loops[0][0] is not loops[1][0]
68-
assert loops[0][1] is loops[1][1]
69-
assert loops[0][1] is asyncio.get_event_loop_policy()
7067

7168

7269
def test_same_task(old_policy):

trio_asyncio/async_.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,16 @@ def _main_loop_exit(self):
104104
super()._main_loop_exit()
105105
self._thread = None
106106

107-
from .loop import _current_loop
108-
outer_loop = _current_loop.loop
107+
from .loop import current_loop, current_policy, TrioPolicy
109108

110109
async with trio.open_nursery() as nursery:
110+
policy = current_policy.get()
111+
if not isinstance(policy, TrioPolicy):
112+
policy = TrioPolicy()
113+
old_policy = current_policy.set(policy)
114+
111115
loop = TrioEventLoop(queue_len=queue_len)
112-
_current_loop.loop = loop
116+
old_loop = current_loop.set(loop)
113117
try:
114118
loop._closed = False
115119
await loop._main_loop_init(nursery)
@@ -122,4 +126,6 @@ def _main_loop_exit(self):
122126
await loop._main_loop_exit()
123127
loop.close()
124128
nursery.cancel_scope.cancel()
125-
_current_loop.loop = outer_loop
129+
current_loop.reset(old_loop)
130+
current_policy.reset(old_policy)
131+

trio_asyncio/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def __init__(self, queue_len=None):
169169

170170
def __repr__(self):
171171
try:
172-
return "<%s running=%s>" % (
172+
return "<%s running=%s at 0x%x>" % (
173173
self.__class__.__name__, "closed" if self._closed else "no"
174-
if self._stopped.is_set() else "yes"
174+
if self._stopped.is_set() else "yes", id(self)
175175
)
176176
except Exception as exc:
177177
return "<%s ?:%s>" % (self.__class__.__name__, repr(exc))

trio_asyncio/loop.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import asyncio
77
import warnings
88
import threading
9+
from contextvars import ContextVar
910

1011
from .util import run_future
1112
from .async_ import TrioEventLoop, open_loop
@@ -30,8 +31,8 @@
3031
'TrioPolicy',
3132
]
3233

33-
_current_loop = trio.TaskLocal(loop=None, policy=None)
34-
34+
current_loop = ContextVar('trio_aio_loop', default=None)
35+
current_policy = ContextVar('trio_aio_policy', default=None)
3536

3637
class _TrioPolicy(asyncio.events.BaseDefaultEventLoopPolicy):
3738
_loop_factory = TrioEventLoop
@@ -73,28 +74,23 @@ def get_event_loop(self):
7374
# this creates a new loop in the main task
7475
return super().get_event_loop()
7576
else:
76-
return _current_loop.loop
77+
return current_loop.get()
7778

7879
@property
7980
def current_event_loop(self):
8081
"""The currently-running event loop, if one exists."""
81-
try:
82-
return _current_loop.loop
83-
except RuntimeError:
84-
# in the main thread this would create a new loop
85-
# return super().get_event_loop()
86-
return super().get_event_loop()
82+
loop = current_loop.get()
83+
if loop is None:
84+
loop = super().get_event_loop()
85+
return loop
8786

8887
def set_event_loop(self, loop):
8988
"""Set the current event loop."""
90-
try:
91-
_current_loop.loop = loop
92-
except RuntimeError:
93-
return super().set_event_loop(loop)
89+
current_loop.set(loop)
9490

9591

9692
# We need to monkey-patch asyncio's policy+loop getters to return our
97-
# TrioPolicy+loop whenever we are within Trio.
93+
# TrioPolicy and the current loop whenever we are within Trio.
9894

9995
from asyncio import events as _aio_event
10096

@@ -105,13 +101,13 @@ def set_event_loop(self, loop):
105101

106102
def _new_policy_get():
107103
try:
108-
policy = _current_loop.policy
104+
task = trio.hazmat.current_task()
109105
except RuntimeError:
110-
return _orig_policy_get()
111-
106+
policy = None
107+
else:
108+
policy = task.context[current_policy]
112109
if policy is None:
113-
policy = TrioPolicy()
114-
_current_loop.policy = policy
110+
policy = _orig_policy_get()
115111
return policy
116112

117113

@@ -125,9 +121,15 @@ def _new_policy_get():
125121

126122
def _new_run_get():
127123
try:
128-
return _current_loop.loop
124+
task = trio.hazmat.current_task()
129125
except RuntimeError:
130-
return _orig_run_get()
126+
loop = None
127+
else:
128+
loop = task.context[current_loop]
129+
130+
if loop is None:
131+
loop = _orig_run_get()
132+
return loop
131133

132134

133135
_aio_event._get_running_loop = _new_run_get
@@ -138,10 +140,10 @@ def _new_run_get():
138140

139141

140142
def _new_loop_get():
141-
try:
142-
return _current_loop.loop
143-
except RuntimeError:
144-
return _orig_loop_get()
143+
loop = _new_run_get()
144+
if loop is None:
145+
loop = _orig_loop_get()
146+
return loop
145147

146148

147149
_aio_event.get_event_loop = _new_loop_get
@@ -156,11 +158,11 @@ def _init_watcher(self):
156158
if self._watcher is None: # pragma: no branch
157159
self._watcher = TrioChildWatcher()
158160
if isinstance(threading.current_thread(), threading._MainThread):
159-
self._watcher.attach_loop(_current_loop.loop)
161+
self._watcher.attach_loop(current_loop.get())
160162

161163
if self._watcher is not None and \
162164
isinstance(threading.current_thread(), threading._MainThread):
163-
self._watcher.attach_loop(_current_loop.loop)
165+
self._watcher.attach_loop(current_loop.get())
164166

165167
def set_child_watcher(self, watcher):
166168
if watcher is not None:

0 commit comments

Comments
 (0)