Skip to content

Commit e7fc624

Browse files
committed
Support using asyncio async generators in trio.
1 parent e1ae85d commit e7fc624

4 files changed

Lines changed: 70 additions & 4 deletions

File tree

trio_asyncio/adapter.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This code implements a clone of the asyncio mainloop which hooks into
22
# Trio.
33

4+
import inspect
45
import trio_asyncio
56

67
# import logging
@@ -12,9 +13,15 @@
1213

1314

1415
def trio2aio(proc):
15-
@wraps(proc)
16-
async def call(*args):
17-
return await trio_asyncio.run_asyncio(proc, *args)
16+
if inspect.isasyncgenfunction(proc):
17+
@wraps(proc)
18+
def call(*args):
19+
return trio_asyncio.wrap_generator(proc, *args)
20+
21+
else:
22+
@wraps(proc)
23+
async def call(*args):
24+
return await trio_asyncio.run_asyncio(proc, *args)
1825

1926
return call
2027

trio_asyncio/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import sys
3+
import inspect
34
import math
45
import trio
56
import heapq
@@ -11,7 +12,7 @@
1112

1213
from selectors import _BaseSelectorImpl, EVENT_READ, EVENT_WRITE
1314

14-
from .util import run_future
15+
from .util import run_future, run_generator
1516

1617
try:
1718
from trio.hazmat import wait_for_child
@@ -212,6 +213,11 @@ async def run_coroutine(self, coro):
212213
coro = asyncio.ensure_future(coro, loop=self)
213214
return await run_future(coro)
214215

216+
def wrap_generator(self, gen):
217+
# if inspect.isasyncgen(f):
218+
# return self.wrap_generator(f)
219+
return run_generator(self, gen())
220+
215221
async def run_asyncio(self, proc, *args):
216222
"""Run an asyncio function or method from Trio.
217223

trio_asyncio/loop.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
'run_future',
2626
'run_coroutine',
2727
'run_asyncio',
28+
'wrap_generator',
2829
'TrioChildWatcher',
2930
'TrioPolicy',
3031
]
@@ -169,6 +170,13 @@ def __exit__(self, *tb):
169170
self.close()
170171

171172

173+
def wrap_generator(proc, *args):
174+
loop = asyncio.get_event_loop()
175+
if not isinstance(loop, TrioEventLoop):
176+
raise RuntimeError("Need to run in a trio_asyncio.open_loop() context")
177+
return loop.wrap_generator(proc, *args)
178+
179+
172180
async def run_asyncio(proc, *args):
173181
loop = asyncio.get_event_loop()
174182
if not isinstance(loop, TrioEventLoop):

trio_asyncio/util.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,48 @@ def abort_cb(raise_cancel_arg):
4545
sys.exc_info()[1].__cause__ = exc
4646
else:
4747
raise
48+
49+
50+
STOP = object()
51+
52+
async def run_generator(loop, async_generator):
53+
task = trio.hazmat.current_task()
54+
raise_cancel = None
55+
56+
async def consume_next():
57+
try:
58+
item = await async_generator.__anext__()
59+
except StopAsyncIteration:
60+
item = STOP
61+
62+
trio.hazmat.reschedule(task, trio.hazmat.Value(value=item))
63+
#trio.hazmat.reschedule(task, STOP)
64+
65+
def abort_cb(raise_cancel_arg):
66+
# Save the cancel-raising function
67+
nonlocal raise_cancel
68+
raise_cancel = raise_cancel_arg
69+
# XXX: we need to cancel any actice consume_next() call.
70+
# Keep waiting
71+
return trio.hazmat.Abort.FAILED
72+
73+
try:
74+
while True:
75+
# schedule that we read the next one from the iterator
76+
asyncio.ensure_future(consume_next(), loop=loop)
77+
78+
item = await trio.hazmat.wait_task_rescheduled(abort_cb)
79+
if item == STOP:
80+
break
81+
yield item
82+
83+
except asyncio.CancelledError as exc:
84+
if raise_cancel is not None:
85+
try:
86+
raise_cancel()
87+
finally:
88+
# Try to preserve the exception chain,
89+
# for more detailed tracebacks
90+
sys.exc_info()[1].__cause__ = exc
91+
else:
92+
raise

0 commit comments

Comments
 (0)