You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

330 lines
9.6 KiB

4 years ago
  1. import asyncio
  2. import contextlib
  3. import warnings
  4. from collections.abc import Callable
  5. import pytest
  6. from aiohttp.helpers import isasyncgenfunction
  7. from aiohttp.web import Application
  8. from .test_utils import (BaseTestServer, RawTestServer, TestClient, TestServer,
  9. loop_context, setup_test_loop, teardown_test_loop)
  10. from .test_utils import unused_port as _unused_port
  11. try:
  12. import uvloop
  13. except ImportError: # pragma: no cover
  14. uvloop = None
  15. try:
  16. import tokio
  17. except ImportError: # pragma: no cover
  18. tokio = None
  def pytest_addoption(parser):
      """Register the aiohttp command line options on the pytest parser.

      Three options are added, in this order: ``--aiohttp-fast``,
      ``--aiohttp-loop`` and ``--aiohttp-enable-loop-debug``.
      """
      option_table = (
          ('--aiohttp-fast', {
              'action': 'store_true',
              'default': False,
              'help': 'run tests faster by disabling extra checks',
          }),
          ('--aiohttp-loop', {
              'action': 'store',
              'default': 'pyloop',
              'help': 'run tests with specific loop: pyloop, uvloop, tokio or all',
          }),
          ('--aiohttp-enable-loop-debug', {
              'action': 'store_true',
              'default': False,
              'help': 'enable event loop debug mode',
          }),
      )
      for flag, settings in option_table:
          parser.addoption(flag, **settings)
  def pytest_fixture_setup(fixturedef):
      """
      Let fixtures be written as coroutines or async generators.

      Synchronous fixtures are left untouched.  For an asynchronous fixture,
      ``fixturedef.func`` is replaced by a synchronous wrapper that drives
      the original function on the event loop obtained from the 'loop'
      fixture.
      """
      fixture_func = fixturedef.func

      # Classify the fixture function; anything that is not async is skipped.
      if isasyncgenfunction(fixture_func):
          async_gen_fixture = True
      elif asyncio.iscoroutinefunction(fixture_func):
          async_gen_fixture = False
      else:
          return

      # The wrapper needs the 'request' object.  When the fixture did not ask
      # for it, add it to the argnames here and remove it again before
      # calling the real fixture function.
      request_injected = 'request' not in fixturedef.argnames
      if request_injected:
          fixturedef.argnames += ('request',)

      def wrapper(*args, **kwargs):
          request = kwargs['request']
          if request_injected:
              del kwargs['request']

          # if neither the fixture nor the test use the 'loop' fixture,
          # 'getfixturevalue' will fail because the test is not parameterized
          # (this can be removed someday if 'loop' is no longer parameterized)
          if 'loop' not in request.fixturenames:
              raise Exception(
                  "Asynchronous fixtures must depend on the 'loop' fixture or "
                  "be used in tests depending from it."
              )

          event_loop = request.getfixturevalue('loop')

          if not async_gen_fixture:
              # plain coroutine fixture: its result is the fixture value
              return event_loop.run_until_complete(
                  fixture_func(*args, **kwargs))

          # async generator fixture: the first step produces the fixture
          # value, the second step (run from a finalizer) performs teardown
          agen = fixture_func(*args, **kwargs)

          def finalizer():
              try:
                  return event_loop.run_until_complete(agen.__anext__())
              except StopAsyncIteration:  # NOQA
                  pass

          request.addfinalizer(finalizer)
          return event_loop.run_until_complete(agen.__anext__())

      fixturedef.func = wrapper
  @pytest.fixture
  def fast(request):
      """Value of the --aiohttp-fast config option"""
      config = request.config
      return config.getoption('--aiohttp-fast')
  @pytest.fixture
  def loop_debug(request):
      """Value of the --aiohttp-enable-loop-debug config option"""
      config = request.config
      return config.getoption('--aiohttp-enable-loop-debug')
  82. @contextlib.contextmanager
  83. def _runtime_warning_context():
  84. """
  85. Context manager which checks for RuntimeWarnings, specifically to
  86. avoid "coroutine 'X' was never awaited" warnings being missed.
  87. If RuntimeWarnings occur in the context a RuntimeError is raised.
  88. """
  89. with warnings.catch_warnings(record=True) as _warnings:
  90. yield
  91. rw = ['{w.filename}:{w.lineno}:{w.message}'.format(w=w)
  92. for w in _warnings if w.category == RuntimeWarning]
  93. if rw:
  94. raise RuntimeError('{} Runtime Warning{},\n{}'.format(
  95. len(rw),
  96. '' if len(rw) == 1 else 's',
  97. '\n'.join(rw)
  98. ))
  99. @contextlib.contextmanager
  100. def _passthrough_loop_context(loop, fast=False):
  101. """
  102. setups and tears down a loop unless one is passed in via the loop
  103. argument when it's passed straight through.
  104. """
  105. if loop:
  106. # loop already exists, pass it straight through
  107. yield loop
  108. else:
  109. # this shadows loop_context's standard behavior
  110. loop = setup_test_loop()
  111. yield loop
  112. teardown_test_loop(loop, fast=fast)
  def pytest_pycollect_makeitem(collector, name, obj):
      """
      Make pytest collect coroutine functions as tests.

      Returns the list of items generated by the collector when the name
      passes the collector's function-name filter and the object is a
      coroutine function; returns None otherwise.
      """
      if not collector.funcnamefilter(name):
          return None
      if not asyncio.iscoroutinefunction(obj):
          return None
      return list(collector._genfunctions(name, obj))
  def pytest_pyfunc_call(pyfuncitem):
      """
      Run coroutine test functions on an event loop.

      Returns True after a coroutine test function has been run to
      completion; returns None for any other test function.
      """
      fast = pyfuncitem.config.getoption("--aiohttp-fast")
      if not asyncio.iscoroutinefunction(pyfuncitem.function):
          return None

      # reuse the loop from the 'loop' fixture when the test requested it
      existing_loop = pyfuncitem.funcargs.get('loop', None)
      with _runtime_warning_context(), \
              _passthrough_loop_context(existing_loop, fast=fast) as _loop:
          # pass only the fixtures the test function itself declares
          testargs = {}
          for arg in pyfuncitem._fixtureinfo.argnames:
              testargs[arg] = pyfuncitem.funcargs[arg]
          _loop.run_until_complete(pyfuncitem.obj(**testargs))
      return True
  def pytest_generate_tests(metafunc):
      """
      Parametrize the 'loop_factory' fixture from the --aiohttp-loop option.

      Does nothing for tests that do not use 'loop_factory'.  The option
      value is a comma separated list of loop names; 'all' expands to
      'pyloop,uvloop?,tokio?'.  A name ending in '?' is optional and is
      skipped when that loop is not available.

      Raises ValueError for a required loop name that is not available; the
      message lists the loops that are available.
      """
      if 'loop_factory' not in metafunc.fixturenames:
          return

      loops = metafunc.config.option.aiohttp_loop
      avail_factories = {'pyloop': asyncio.DefaultEventLoopPolicy}

      if uvloop is not None:  # pragma: no cover
          avail_factories['uvloop'] = uvloop.EventLoopPolicy

      if tokio is not None:  # pragma: no cover
          avail_factories['tokio'] = tokio.EventLoopPolicy

      if loops == 'all':
          loops = 'pyloop,uvloop?,tokio?'

      factories = {}
      for name in loops.split(','):
          required = not name.endswith('?')
          name = name.strip(' ?')
          if name not in avail_factories:  # pragma: no cover
              if required:
                  # report what can actually be selected, not the loops
                  # that happen to have been selected so far
                  raise ValueError(
                      "Unknown loop '%s', available loops: %s" % (
                          name, list(avail_factories.keys())))
              else:
                  continue
          factories[name] = avail_factories[name]
      metafunc.parametrize("loop_factory",
                           list(factories.values()),
                           ids=list(factories.keys()))
  @pytest.fixture
  def loop(loop_factory, fast, loop_debug):
      """Yield an event loop created under the policy from loop_factory.

      The policy returned by loop_factory() is installed as the asyncio
      event loop policy, then a loop obtained from loop_context(fast=fast)
      is set as the current event loop and yielded.  Debug mode is switched
      on for the loop when loop_debug is true.
      """
      asyncio.set_event_loop_policy(loop_factory())
      with loop_context(fast=fast) as event_loop:
          if loop_debug:
              event_loop.set_debug(True)  # pragma: no cover
          asyncio.set_event_loop(event_loop)
          yield event_loop
  @pytest.fixture
  def unused_port(aiohttp_unused_port):  # pragma: no cover
      """Deprecated alias of the aiohttp_unused_port fixture.

      Issues a DeprecationWarning and returns the aiohttp_unused_port value.
      """
      message = "Deprecated, use aiohttp_unused_port fixture instead"
      warnings.warn(message, DeprecationWarning)
      return aiohttp_unused_port
  @pytest.fixture
  def aiohttp_unused_port():
      """Return the callable used to find a port unused on the current host.

      The fixture value is the ``unused_port`` helper imported from
      ``.test_utils`` itself, not a port number.
      """
      port_finder = _unused_port
      return port_finder
  @pytest.fixture
  def aiohttp_server(loop):
      """Factory to create and start a TestServer instance, given an app.

      aiohttp_server(app, **kwargs)

      ``port`` is passed to TestServer; the remaining keyword arguments are
      passed to ``start_server``.  Every server created through the factory
      is closed when the fixture is torn down.
      """
      started = []

      async def go(app, *, port=None, **kwargs):
          srv = TestServer(app, port=port)
          await srv.start_server(loop=loop, **kwargs)
          started.append(srv)
          return srv

      yield go

      async def finalize():
          # close the most recently started server first
          while started:
              last = started.pop()
              await last.close()

      loop.run_until_complete(finalize())
  @pytest.fixture
  def test_server(aiohttp_server):  # pragma: no cover
      """Deprecated alias of the aiohttp_server fixture.

      Issues a DeprecationWarning and returns the aiohttp_server value.
      """
      message = "Deprecated, use aiohttp_server fixture instead"
      warnings.warn(message, DeprecationWarning)
      return aiohttp_server
  @pytest.fixture
  def aiohttp_raw_server(loop):
      """Factory to create and start a RawTestServer, given a web handler.

      aiohttp_raw_server(handler, **kwargs)

      ``port`` is passed to RawTestServer; the remaining keyword arguments
      are passed to ``start_server``.  Every server created through the
      factory is closed when the fixture is torn down.
      """
      started = []

      async def go(handler, *, port=None, **kwargs):
          srv = RawTestServer(handler, port=port)
          await srv.start_server(loop=loop, **kwargs)
          started.append(srv)
          return srv

      yield go

      async def finalize():
          # close the most recently started server first
          while started:
              last = started.pop()
              await last.close()

      loop.run_until_complete(finalize())
  @pytest.fixture
  def raw_test_server(aiohttp_raw_server):  # pragma: no cover
      """Deprecated alias of the aiohttp_raw_server fixture.

      Issues a DeprecationWarning and returns the aiohttp_raw_server value.
      """
      message = "Deprecated, use aiohttp_raw_server fixture instead"
      warnings.warn(message, DeprecationWarning)
      return aiohttp_raw_server
  @pytest.fixture
  def aiohttp_client(loop):
      """Factory to create a started TestClient instance.

      aiohttp_client(app, **kwargs)
      aiohttp_client(server, **kwargs)
      aiohttp_client(raw_server, **kwargs)

      The first argument may also be a callable that is neither an
      Application nor a test server; it is then called as
      ``factory(loop, *args, **kwargs)`` and its result is used in its place.
      Every client created through the factory is closed when the fixture is
      torn down.
      """
      created = []

      async def go(__param, *args, server_kwargs=None, **kwargs):
          is_factory = (
              isinstance(__param, Callable) and
              not isinstance(__param, (Application, BaseTestServer)))
          if is_factory:
              # the factory consumes the positional and keyword arguments,
              # so nothing is left over for TestClient
              __param = __param(loop, *args, **kwargs)
              kwargs = {}
          else:
              assert not args, "args should be empty"

          if isinstance(__param, Application):
              # wrap a bare application in a TestServer first
              server_kwargs = server_kwargs or {}
              target = TestServer(__param, loop=loop, **server_kwargs)
          elif isinstance(__param, BaseTestServer):
              target = __param
          else:
              raise ValueError("Unknown argument type: %r" % type(__param))

          client = TestClient(target, loop=loop, **kwargs)
          await client.start_server()
          created.append(client)
          return client

      yield go

      async def finalize():
          # close the most recently created client first
          while created:
              last = created.pop()
              await last.close()

      loop.run_until_complete(finalize())
  @pytest.fixture
  def test_client(aiohttp_client):  # pragma: no cover
      """Deprecated alias of the aiohttp_client fixture.

      Issues a DeprecationWarning and returns the aiohttp_client value.
      """
      message = "Deprecated, use aiohttp_client fixture instead"
      warnings.warn(message, DeprecationWarning)
      return aiohttp_client