Skip to content

Commit daf81a9

Browse files
committed
Merge pull request #18
2 parents 3a06dcf + e1feef8 commit daf81a9

6 files changed

Lines changed: 292 additions & 18 deletions

File tree

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ def all_tasks(loop=None):
2424

2525
asyncio.all_tasks = all_tasks
2626

27+
if not hasattr(asyncio, 'create_task'):
28+
29+
if hasattr(asyncio.events, 'get_running_loop'):
30+
def create_task(coro):
31+
loop = asyncio.events.get_running_loop()
32+
return loop.create_task(coro)
33+
else:
34+
def create_task(coro):
35+
loop = asyncio.events._get_running_loop()
36+
return loop.create_task(coro)
37+
38+
asyncio.create_task = create_task
2739

2840
@pytest.fixture
2941
async def loop():

tests/interop/test_calls.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@ class Seen:
99
flag = 0
1010

1111

12+
async def async_gen_to_list(generator):
13+
result = []
14+
async for item in generator:
15+
result.append(item)
16+
return result
17+
18+
1219
class TestCalls(aiotest.TestCase):
1320
async def call_t_a(self, proc, *args, loop=None):
1421
"""called from Trio"""
@@ -231,3 +238,24 @@ def err_asyncio():
231238
with pytest.raises(RuntimeError) as err:
232239
await self.call_t_a(err_asyncio, loop=loop)
233240
assert err.value.args[0] == "I has an owie"
241+
242+
@pytest.mark.trio
243+
async def test_trio_asyncio_generator(self, loop):
244+
async def dly_asyncio():
245+
yield 1
246+
await asyncio.sleep(0.01, loop=loop)
247+
yield 2
248+
249+
res = await async_gen_to_list(loop.wrap_generator(dly_asyncio))
250+
assert res == [1, 2]
251+
252+
@pytest.mark.trio
253+
async def test_trio_asyncio_generator_with_error(self, loop):
254+
async def dly_asyncio():
255+
yield 1
256+
raise RuntimeError("I has an owie")
257+
yield 2
258+
259+
with pytest.raises(RuntimeError) as err:
260+
await async_gen_to_list(loop.wrap_generator(dly_asyncio))
261+
assert err.value.args[0] == "I has an owie"

tests/python/test_tasks.py

Lines changed: 240 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22

33
import collections
44
import contextlib
5+
import contextvars
56
import functools
67
import gc
78
import io
89
import os
910
import re
1011
import sys
1112
import types
13+
import random
1214
import unittest
15+
import textwrap
1316
import weakref
1417
from unittest import mock
1518
import pytest
@@ -72,6 +75,20 @@ def __call__(self, *args):
7275
pass
7376

7477

78+
class CoroLikeObject:
79+
def send(self, v):
80+
raise StopIteration(42)
81+
82+
def throw(self, *exc):
83+
pass
84+
85+
def close(self):
86+
pass
87+
88+
def __await__(self):
89+
return self
90+
91+
7592
class BaseTaskTests:
7693

7794
Task = None
@@ -2027,6 +2044,158 @@ def coro():
20272044

20282045
self.assertEqual(asyncio.all_tasks(self.loop), set())
20292046

2047+
def test_create_task_with_noncoroutine(self):
2048+
with self.assertRaisesRegex(TypeError,
2049+
"a coroutine was expected, got 123"):
2050+
self.new_task(self.loop, 123)
2051+
2052+
# test it for the second time to ensure that caching
2053+
# in asyncio.iscoroutine() doesn't break things.
2054+
with self.assertRaisesRegex(TypeError,
2055+
"a coroutine was expected, got 123"):
2056+
self.new_task(self.loop, 123)
2057+
2058+
def test_create_task_with_oldstyle_coroutine(self):
2059+
2060+
@asyncio.coroutine
2061+
def coro():
2062+
pass
2063+
2064+
task = self.new_task(self.loop, coro())
2065+
self.assertIsInstance(task, asyncio.Task)
2066+
self.loop.run_until_complete(task)
2067+
2068+
# test it for the second time to ensure that caching
2069+
# in asyncio.iscoroutine() doesn't break things.
2070+
task = self.new_task(self.loop, coro())
2071+
self.assertIsInstance(task, asyncio.Task)
2072+
self.loop.run_until_complete(task)
2073+
2074+
def test_create_task_with_async_function(self):
2075+
2076+
async def coro():
2077+
pass
2078+
2079+
task = self.new_task(self.loop, coro())
2080+
self.assertIsInstance(task, asyncio.Task)
2081+
self.loop.run_until_complete(task)
2082+
2083+
# test it for the second time to ensure that caching
2084+
# in asyncio.iscoroutine() doesn't break things.
2085+
task = self.new_task(self.loop, coro())
2086+
self.assertIsInstance(task, asyncio.Task)
2087+
self.loop.run_until_complete(task)
2088+
2089+
def test_create_task_with_asynclike_function(self):
2090+
task = self.new_task(self.loop, CoroLikeObject())
2091+
self.assertIsInstance(task, asyncio.Task)
2092+
self.assertEqual(self.loop.run_until_complete(task), 42)
2093+
2094+
# test it for the second time to ensure that caching
2095+
# in asyncio.iscoroutine() doesn't break things.
2096+
task = self.new_task(self.loop, CoroLikeObject())
2097+
self.assertIsInstance(task, asyncio.Task)
2098+
self.assertEqual(self.loop.run_until_complete(task), 42)
2099+
2100+
def test_bare_create_task(self):
2101+
2102+
async def inner():
2103+
return 1
2104+
2105+
async def coro():
2106+
task = asyncio.create_task(inner())
2107+
self.assertIsInstance(task, asyncio.Task)
2108+
ret = await task
2109+
self.assertEqual(1, ret)
2110+
2111+
self.loop.run_until_complete(coro())
2112+
2113+
def test_context_1(self):
2114+
cvar = contextvars.ContextVar('cvar', default='nope')
2115+
2116+
async def sub():
2117+
import pdb;pdb.set_trace()
2118+
await asyncio.sleep(0.1, loop=loop)
2119+
self.assertEqual(cvar.get(), 'nope')
2120+
cvar.set('something else')
2121+
2122+
async def main():
2123+
self.assertEqual(cvar.get(), 'nope')
2124+
subtask = self.new_task(loop, sub())
2125+
cvar.set('yes')
2126+
self.assertEqual(cvar.get(), 'yes')
2127+
await subtask
2128+
self.assertEqual(cvar.get(), 'yes')
2129+
2130+
loop = asyncio.new_event_loop()
2131+
try:
2132+
task = self.new_task(loop, main())
2133+
loop.run_until_complete(task)
2134+
finally:
2135+
loop.close()
2136+
2137+
def test_context_2(self):
2138+
cvar = contextvars.ContextVar('cvar', default='nope')
2139+
2140+
async def main():
2141+
def fut_on_done(fut):
2142+
# This change must not pollute the context
2143+
# of the "main()" task.
2144+
cvar.set('something else')
2145+
2146+
self.assertEqual(cvar.get(), 'nope')
2147+
2148+
for j in range(2):
2149+
fut = self.new_future(loop)
2150+
fut.add_done_callback(fut_on_done)
2151+
cvar.set(f'yes{j}')
2152+
loop.call_soon(fut.set_result, None)
2153+
await fut
2154+
self.assertEqual(cvar.get(), f'yes{j}')
2155+
2156+
for i in range(3):
2157+
# Test that task passed its context to add_done_callback:
2158+
cvar.set(f'yes{i}-{j}')
2159+
await asyncio.sleep(0.001, loop=loop)
2160+
self.assertEqual(cvar.get(), f'yes{i}-{j}')
2161+
2162+
loop = asyncio.new_event_loop()
2163+
try:
2164+
task = self.new_task(loop, main())
2165+
loop.run_until_complete(task)
2166+
finally:
2167+
loop.close()
2168+
2169+
self.assertEqual(cvar.get(), 'nope')
2170+
2171+
def test_context_3(self):
2172+
# Run 100 Tasks in parallel, each modifying cvar.
2173+
2174+
cvar = contextvars.ContextVar('cvar', default=-1)
2175+
2176+
async def sub(num):
2177+
for i in range(10):
2178+
cvar.set(num + i)
2179+
await asyncio.sleep(
2180+
random.uniform(0.001, 0.05), loop=loop)
2181+
self.assertEqual(cvar.get(), num + i)
2182+
2183+
async def main():
2184+
tasks = []
2185+
for i in range(100):
2186+
task = loop.create_task(sub(random.randint(0, 10)))
2187+
tasks.append(task)
2188+
2189+
await asyncio.gather(*tasks, loop=loop)
2190+
2191+
loop = asyncio.new_event_loop()
2192+
try:
2193+
loop.run_until_complete(main())
2194+
finally:
2195+
loop.close()
2196+
2197+
self.assertEqual(cvar.get(), -1)
2198+
20302199

20312200
def add_subclass_tests(cls):
20322201
BaseTask = cls.Task
@@ -2067,7 +2236,7 @@ async def func():
20672236
self.loop.call_soon(lambda: fut.set_result('spam'))
20682237
return await fut
20692238

2070-
task = self.Task(func(), loop=self.loop)
2239+
task = asyncio.Task(func(), loop=self.loop)
20712240

20722241
result = self.loop.run_until_complete(task)
20732242

@@ -2484,22 +2653,29 @@ def test_run_coroutine_threadsafe_task_cancelled(self):
24842653
with self.assertRaises(asyncio.CancelledError):
24852654
self.loop.run_until_complete(future)
24862655

2487-
@unittest.skip("XXX does not terminate")
2656+
@unittest.skip("trio-asyncio doesn't use a task factory")
24882657
def test_run_coroutine_threadsafe_task_factory_exception(self):
24892658
"""Test coroutine submission from a tread to an event loop
24902659
when the task factory raise an exception."""
2491-
# Schedule the target
2492-
future = self.loop.run_in_executor(None, lambda: self.target(advance_coro=True))
2493-
# Set corrupted task factory
2494-
self.loop.set_task_factory(lambda loop, coro: wrong_name) # noqa: F821
2660+
2661+
def task_factory(loop, coro):
2662+
raise NameError
2663+
2664+
run = self.loop.run_in_executor(
2665+
None, lambda: self.target(advance_coro=True))
2666+
24952667
# Set exception handler
24962668
callback = test_utils.MockCallback()
24972669
self.loop.set_exception_handler(callback)
2670+
2671+
# Set corrupted task factory
2672+
self.loop.set_task_factory(task_factory)
2673+
24982674
# Run event loop
24992675
with self.assertRaises(NameError) as exc_context:
2500-
self.loop.run_until_complete(future)
2676+
self.loop.run_until_complete(run)
2677+
25012678
# Check exceptions
2502-
self.assertIn('wrong_name', exc_context.exception.args[0])
25032679
self.assertEqual(len(callback.call_args_list), 1)
25042680
(loop, context), kwargs = callback.call_args
25052681
self.assertEqual(context['exception'], exc_context.exception)
@@ -2534,6 +2710,62 @@ def coro():
25342710
self.loop.run_until_complete(coro())
25352711
self.assertEqual(result, 11)
25362712

2713+
class CompatibilityTests(test_utils.TestCase):
2714+
# Tests for checking a bridge between old-styled coroutines
2715+
# and async/await syntax
2716+
2717+
def setUp(self):
2718+
super().setUp()
2719+
self.loop = asyncio.new_event_loop()
2720+
asyncio.set_event_loop(None)
2721+
2722+
def tearDown(self):
2723+
self.loop.close()
2724+
self.loop = None
2725+
super().tearDown()
2726+
2727+
def test_yield_from_awaitable(self):
2728+
2729+
@asyncio.coroutine
2730+
def coro():
2731+
yield from asyncio.sleep(0, loop=self.loop)
2732+
return 'ok'
2733+
2734+
result = self.loop.run_until_complete(coro())
2735+
self.assertEqual('ok', result)
2736+
2737+
def test_await_old_style_coro(self):
2738+
2739+
@asyncio.coroutine
2740+
def coro1():
2741+
return 'ok1'
2742+
2743+
@asyncio.coroutine
2744+
def coro2():
2745+
yield from asyncio.sleep(0, loop=self.loop)
2746+
return 'ok2'
2747+
async def inner():
2748+
return await asyncio.gather(coro1(), coro2(), loop=self.loop)
2749+
2750+
result = self.loop.run_until_complete(inner())
2751+
self.assertEqual(['ok1', 'ok2'], result)
2752+
2753+
def test_debug_mode_interop(self):
2754+
# https://bugs.python.org/issue32636
2755+
code = textwrap.dedent("""
2756+
import asyncio
2757+
2758+
async def native_coro():
2759+
pass
2760+
2761+
@asyncio.coroutine
2762+
def old_style_coro():
2763+
yield from native_coro()
2764+
2765+
asyncio.run(old_style_coro())
2766+
""")
2767+
assert_python_ok("-c", code, PYTHONASYNCIODEBUG="1")
2768+
25372769

25382770
if __name__ == '__main__':
25392771
unittest.main()

tests/utils.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -524,16 +524,13 @@ def new_test_loop(self, gen=None):
524524
self.set_event_loop(loop)
525525
return loop
526526

527-
def unpatch_get_running_loop(self):
528-
events._get_running_loop = self._get_running_loop
529-
530527
def setUp(self):
531-
self._get_running_loop = events._get_running_loop
532-
events._get_running_loop = lambda: None
528+
#self._get_running_loop = events._get_running_loop
529+
#events._get_running_loop = lambda: None
533530
self._thread_cleanup = support.threading_setup()
534531

535532
def tearDown(self):
536-
self.unpatch_get_running_loop()
533+
#events._get_running_loop = self._get_running_loop
537534

538535
events.set_event_loop(None)
539536

trio_asyncio/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,10 @@ async def _main_loop_one(self, no_wait=False):
705705
# Don't go through the expensive nursery dance
706706
# if this is a sync function.
707707
if getattr(obj, '_is_sync', True):
708-
obj._callback(*obj._args)
708+
if hasattr(obj, '_context'):
709+
obj._context.run(obj._callback, *obj._args)
710+
else:
711+
obj._callback(*obj._args)
709712
else:
710713
await self._nursery.start(obj._call_async)
711714

0 commit comments

Comments
 (0)