|
2 | 2 | # Trio. |
3 | 3 |
|
4 | 4 | import types |
| 5 | +import warnings |
| 6 | + |
5 | 7 | from async_generator import isasyncgenfunction |
6 | 8 | import asyncio |
7 | 9 | import trio_asyncio |
8 | | -from contextvars import copy_context |
| 10 | +from contextvars import ContextVar |
| 11 | + |
| 12 | +from .util import run_aio_generator, run_aio_future |
| 13 | + |
| 14 | +current_loop = ContextVar('trio_aio_loop', default=None) |
| 15 | +current_policy = ContextVar('trio_aio_policy', default=None) |
9 | 16 |
|
10 | 17 | # import logging |
11 | 18 | # logger = logging.getLogger(__name__) |
12 | 19 |
|
13 | 20 | from functools import wraps, partial |
14 | 21 |
|
15 | | -__all__ = ['trio2aio', 'aio2trio', 'allow_asyncio'] |
| 22 | +__all__ = ['trio2aio', 'aio2trio', 'aio_as_trio', 'allow_asyncio', |
| 23 | + 'current_loop', 'current_policy'] |
16 | 24 |
|
17 | 25 |
|
18 | 26 | def trio2aio(proc): |
19 | | - if isasyncgenfunction(proc): |
| 27 | + """Call asyncio code from Trio. |
20 | 28 |
|
21 | | - @wraps(proc) |
22 | | - def call(*args, **kwargs): |
23 | | - proc_ = proc |
24 | | - if kwargs: |
25 | | - proc_ = partial(proc_, **kwargs) |
26 | | - return trio_asyncio.wrap_generator(proc_, *args) |
| 29 | + Deprecated: Use aio_as_trio() instead. |
27 | 30 |
|
28 | | - else: |
| 31 | + await loop.run_iterator(iter) |
29 | 32 |
|
30 | | - @wraps(proc) |
31 | | - async def call(*args, **kwargs): |
32 | | - proc_ = proc |
33 | | - if kwargs: |
34 | | - proc_ = partial(proc_, **kwargs) |
35 | | - return await trio_asyncio.run_asyncio(proc_, *args) |
| 33 | + simply call |
36 | 34 |
|
37 | | - return call |
| 35 | + await loop.run_asyncio(iter) |
| 36 | + """ |
| 37 | + warnings.warn("Use 'aio_as_trio(proc)' instead'", DeprecationWarning) |
| 38 | + |
| 39 | + return aio_as_trio(proc) |
| 40 | + |
| 41 | +class Asyncio_Trio_Wrapper: |
| 42 | + """ |
| 43 | + This wrapper object encapsulates an asyncio-style coroutine, |
| 44 | + generator, or iterator, to be called seamlessly from Trio. |
| 45 | + """ |
| 46 | + def __init__(self, proc, args=[], loop=None): |
| 47 | + self.proc = proc |
| 48 | + self.args = args |
| 49 | + self._loop = loop |
| 50 | + |
| 51 | + @property |
| 52 | + def loop(self): |
| 53 | + """The loop argument needs to be lazily evaluated.""" |
| 54 | + loop = self._loop |
| 55 | + if loop is None: |
| 56 | + loop = current_loop.get() |
| 57 | + return loop |
| 58 | + |
| 59 | + def __get__(self, obj, cls): |
| 60 | + """If this is used to decorate an instance, |
| 61 | + we need to forward the original ``self`` to the wrapped method. |
| 62 | + """ |
| 63 | + if obj is None: |
| 64 | + return self.__call__ |
| 65 | + return partial(self.__call__, obj) |
| 66 | + |
| 67 | + async def __call__(self, *args, **kwargs): |
| 68 | + if self.args: |
| 69 | + "Call 'aio_as_trio(oroc)(*args)', not 'aio_as_trio(proc, *args)'" |
| 70 | + |
| 71 | + f = self.proc(*args, **kwargs) |
| 72 | + return await self.loop.run_aio_coroutine(f) |
| 73 | + |
| 74 | + def __await__(self): |
| 75 | + """Compatbility code for loop.run_asyncio""" |
| 76 | + f = self.proc(*self.args) |
| 77 | + return self.loop.run_aio_coroutine(f).__await__() |
| 78 | + |
| 79 | + def __aenter__(self): |
| 80 | + proc_enter = getattr(self.proc, "__aenter__", None) |
| 81 | + if proc_enter is None or self.args: |
| 82 | + raise RuntimeError( |
| 83 | + "Call 'aio_as_trio(ctxfactory(*args))', not 'aio_as_trio(ctxfactory, *args)'" |
| 84 | + ) |
| 85 | + f = proc_enter() |
| 86 | + return self.loop.run_aio_coroutine(f) |
| 87 | + |
| 88 | + def __aexit__(self, *tb): |
| 89 | + f = self.proc.__aexit__(*tb) |
| 90 | + return self.loop.run_aio_coroutine(f) |
| 91 | + |
| 92 | + def __aiter__(self): |
| 93 | + proc_iter = getattr(self.proc, "__anext__", None) |
| 94 | + if proc_iter is None or self.args: |
| 95 | + raise RuntimeError( |
| 96 | + "Call 'run_asyncio(gen(*args))', not 'run_asyncio(gen, *args)'" |
| 97 | + ) |
| 98 | + return run_aio_generator(self.loop, self.proc) |
| 99 | + |
| 100 | + |
| 101 | +def aio_as_trio(proc, loop=None): |
| 102 | + return Asyncio_Trio_Wrapper(proc, loop=loop) |
38 | 103 |
|
39 | 104 |
|
40 | 105 | def aio2trio(proc): |
|
0 commit comments