|
30 | 30 | 'TrioPolicy', |
31 | 31 | ] |
32 | 32 |
|
| 33 | +_current_loop = trio.TaskLocal(loop=None, policy=None) |
33 | 34 |
|
34 | 35 | class _TrioPolicy(asyncio.events.BaseDefaultEventLoopPolicy): |
35 | 36 | _loop_factory = TrioEventLoop |
36 | 37 |
|
37 | 38 | def __init__(self): |
38 | 39 | super().__init__() |
39 | | - self._trio_local = trio.TaskLocal(_loop=None, _task=False) |
| 40 | + current_loop = trio.TaskLocal(_loop=None, _task=False) |
40 | 41 |
|
41 | 42 | def new_event_loop(self): |
42 | 43 | try: |
@@ -75,45 +76,68 @@ def get_event_loop(self): |
75 | 76 | # this creates a new loop in the main task |
76 | 77 | return super().get_event_loop() |
77 | 78 | else: |
78 | | - return self._trio_local._loop |
| 79 | + return _current_loop.loop |
79 | 80 |
|
80 | 81 | @property |
81 | 82 | def current_event_loop(self): |
82 | 83 | """The currently-running event loop, if one exists.""" |
83 | 84 | try: |
84 | | - return self._trio_local._loop |
| 85 | + return _current_loop.loop |
85 | 86 | except RuntimeError: |
86 | 87 | # in the main thread this would create a new loop |
87 | 88 | # return super().get_event_loop() |
88 | | - return self._local._loop |
| 89 | + return super().get_event_loop() |
89 | 90 |
|
90 | 91 | def set_event_loop(self, loop): |
91 | 92 | """Set the current event loop.""" |
92 | 93 | try: |
93 | | - task = trio.hazmat.current_task() |
| 94 | + _current_loop.loop = loop |
94 | 95 | except RuntimeError: |
95 | 96 | return super().set_event_loop(loop) |
96 | 97 |
|
97 | | - # This test will not trigger if you create a new asyncio event loop |
98 | | - # in a sub-task, which is exactly what we intend to be possible |
99 | | - if self._trio_local._loop is not None and loop is not None and \ |
100 | | - self._trio_local._task == task: |
101 | | - raise RuntimeError('You cannot replace an event loop.', self._trio_local._loop, loop) |
102 | | - self._trio_local._loop = loop |
103 | | - self._trio_local._task = task |
| 98 | + |
| 99 | + |
| 100 | +# We need to monkey-patch asyncio's policy+loop getters to return our |
| 101 | +# TrioPolicy+loop whenever we are within Trio. |
| 102 | + |
| 103 | +from asyncio import events as _aio_event |
| 104 | + |
| 105 | +_orig_policy_get = _aio_event.get_event_loop_policy |
| 106 | +def _new_policy_get(): |
| 107 | + try: |
| 108 | + policy = _current_loop.policy |
| 109 | + except RuntimeError: |
| 110 | + return _orig_policy_get() |
| 111 | + |
| 112 | + if policy is None: |
| 113 | + policy = TrioPolicy() |
| 114 | + _current_loop.policy = policy |
| 115 | + return policy |
| 116 | +_aio_event.get_event_loop_policy = _new_policy_get |
| 117 | +asyncio.get_event_loop_policy = _new_policy_get |
| 118 | + |
| 119 | +_orig_loop_get = _aio_event._get_running_loop |
| 120 | +def _new_loop_get(): |
| 121 | + try: |
| 122 | + return _current_loop.loop |
| 123 | + except RuntimeError: |
| 124 | + return _orig_loop_get() |
| 125 | +_aio_event._get_running_loop = _new_loop_get |
104 | 126 |
|
105 | 127 |
|
106 | 128 | class TrioPolicy(_TrioPolicy, asyncio.DefaultEventLoopPolicy): |
| 129 | + """This is the loop policy that's active whenever we're in a Trio context.""" |
| 130 | + |
107 | 131 | def _init_watcher(self): |
108 | 132 | with asyncio.events._lock: |
109 | 133 | if self._watcher is None: # pragma: no branch |
110 | 134 | self._watcher = TrioChildWatcher() |
111 | 135 | if isinstance(threading.current_thread(), threading._MainThread): |
112 | | - self._watcher.attach_loop(self._trio_local._loop) |
| 136 | + self._watcher.attach_loop(_current_loop.loop) |
113 | 137 |
|
114 | 138 | if self._watcher is not None and \ |
115 | 139 | isinstance(threading.current_thread(), threading._MainThread): |
116 | | - self._watcher.attach_loop(self._trio_local._loop) |
| 140 | + self._watcher.attach_loop(_current_loop.loop) |
117 | 141 |
|
118 | 142 | def set_child_watcher(self, watcher): |
119 | 143 | if watcher is not None: |
|
0 commit comments