|
30 | 30 | 'TrioPolicy', |
31 | 31 | ] |
32 | 32 |
|
| 33 | +_current_loop = trio.TaskLocal(loop=None, policy=None) |
| 34 | + |
33 | 35 |
|
34 | 36 | class _TrioPolicy(asyncio.events.BaseDefaultEventLoopPolicy): |
35 | 37 | _loop_factory = TrioEventLoop |
36 | 38 |
|
37 | | - def __init__(self): |
38 | | - super().__init__() |
39 | | - self._trio_local = trio.hazmat.RunLocal(_loop=None, _task=False) |
40 | | - |
41 | 39 | def new_event_loop(self): |
42 | 40 | try: |
43 | 41 | trio.hazmat.current_task() |
@@ -75,45 +73,94 @@ def get_event_loop(self): |
75 | 73 | # this creates a new loop in the main task |
76 | 74 | return super().get_event_loop() |
77 | 75 | else: |
78 | | - return self._trio_local._loop |
| 76 | + return _current_loop.loop |
79 | 77 |
|
80 | 78 | @property |
81 | 79 | def current_event_loop(self): |
82 | 80 | """The currently-running event loop, if one exists.""" |
83 | 81 | try: |
84 | | - return self._trio_local._loop |
| 82 | + return _current_loop.loop |
85 | 83 | except RuntimeError: |
86 | 84 | # in the main thread this would create a new loop |
87 | 85 | # return super().get_event_loop() |
88 | | - return self._local._loop |
| 86 | + return super().get_event_loop() |
89 | 87 |
|
90 | 88 | def set_event_loop(self, loop): |
91 | 89 | """Set the current event loop.""" |
92 | 90 | try: |
93 | | - task = trio.hazmat.current_task() |
| 91 | + _current_loop.loop = loop |
94 | 92 | except RuntimeError: |
95 | 93 | return super().set_event_loop(loop) |
96 | 94 |
|
97 | | - # This test will not trigger if you create a new asyncio event loop |
98 | | - # in a sub-task, which is exactly what we intend to be possible |
99 | | - if self._trio_local._loop is not None and loop is not None and \ |
100 | | - self._trio_local._task == task: |
101 | | - raise RuntimeError('You cannot replace an event loop.', self._trio_local._loop, loop) |
102 | | - self._trio_local._loop = loop |
103 | | - self._trio_local._task = task |
| 95 | + |
| 96 | +# We need to monkey-patch asyncio's policy+loop getters to return our |
| 97 | +# TrioPolicy+loop whenever we are within Trio. |
| 98 | + |
| 99 | +from asyncio import events as _aio_event |
| 100 | + |
| 101 | +##### |
| 102 | + |
| 103 | +_orig_policy_get = _aio_event.get_event_loop_policy |
| 104 | + |
| 105 | + |
| 106 | +def _new_policy_get(): |
| 107 | + try: |
| 108 | + policy = _current_loop.policy |
| 109 | + except RuntimeError: |
| 110 | + return _orig_policy_get() |
| 111 | + |
| 112 | + if policy is None: |
| 113 | + policy = TrioPolicy() |
| 114 | + _current_loop.policy = policy |
| 115 | + return policy |
| 116 | + |
| 117 | + |
| 118 | +_aio_event.get_event_loop_policy = _new_policy_get |
| 119 | +asyncio.get_event_loop_policy = _new_policy_get |
| 120 | + |
| 121 | +##### |
| 122 | + |
| 123 | +_orig_run_get = _aio_event._get_running_loop |
| 124 | + |
| 125 | + |
| 126 | +def _new_run_get(): |
| 127 | + try: |
| 128 | + return _current_loop.loop |
| 129 | + except RuntimeError: |
| 130 | + return _orig_run_get() |
| 131 | + |
| 132 | + |
| 133 | +_aio_event._get_running_loop = _new_run_get |
| 134 | + |
| 135 | +##### |
| 136 | + |
| 137 | +_orig_loop_get = _aio_event.get_event_loop |
| 138 | + |
| 139 | + |
| 140 | +def _new_loop_get(): |
| 141 | + try: |
| 142 | + return _current_loop.loop |
| 143 | + except RuntimeError: |
| 144 | + return _orig_loop_get() |
| 145 | + |
| 146 | + |
| 147 | +_aio_event.get_event_loop = _new_loop_get |
| 148 | +asyncio.get_event_loop = _new_loop_get |
104 | 149 |
|
105 | 150 |
|
106 | 151 | class TrioPolicy(_TrioPolicy, asyncio.DefaultEventLoopPolicy): |
| 152 | + """This is the loop policy that's active whenever we're in a Trio context.""" |
| 153 | + |
107 | 154 | def _init_watcher(self): |
108 | 155 | with asyncio.events._lock: |
109 | 156 | if self._watcher is None: # pragma: no branch |
110 | 157 | self._watcher = TrioChildWatcher() |
111 | 158 | if isinstance(threading.current_thread(), threading._MainThread): |
112 | | - self._watcher.attach_loop(self._trio_local._loop) |
| 159 | + self._watcher.attach_loop(_current_loop.loop) |
113 | 160 |
|
114 | 161 | if self._watcher is not None and \ |
115 | 162 | isinstance(threading.current_thread(), threading._MainThread): |
116 | | - self._watcher.attach_loop(self._trio_local._loop) |
| 163 | + self._watcher.attach_loop(_current_loop.loop) |
117 | 164 |
|
118 | 165 | def set_child_watcher(self, watcher): |
119 | 166 | if watcher is not None: |
|
0 commit comments