- import asyncio
- import contextlib
- import warnings
- from collections.abc import Callable
- import pytest
- from aiohttp.helpers import isasyncgenfunction
- from aiohttp.web import Application
- from .test_utils import (BaseTestServer, RawTestServer, TestClient, TestServer,
- loop_context, setup_test_loop, teardown_test_loop)
- from .test_utils import unused_port as _unused_port
- try:
- import uvloop
- except ImportError: # pragma: no cover
- uvloop = None
- try:
- import tokio
- except ImportError: # pragma: no cover
- tokio = None
def pytest_addoption(parser):
    """Register the aiohttp command line options on the pytest parser.

    Three options are added, in this order: ``--aiohttp-fast``,
    ``--aiohttp-loop`` and ``--aiohttp-enable-loop-debug``.
    """
    option_table = (
        ('--aiohttp-fast', {
            'action': 'store_true',
            'default': False,
            'help': 'run tests faster by disabling extra checks',
        }),
        ('--aiohttp-loop', {
            'action': 'store',
            'default': 'pyloop',
            'help': ('run tests with specific loop: '
                     'pyloop, uvloop, tokio or all'),
        }),
        ('--aiohttp-enable-loop-debug', {
            'action': 'store_true',
            'default': False,
            'help': 'enable event loop debug mode',
        }),
    )
    for flag, settings in option_table:
        parser.addoption(flag, **settings)
def pytest_fixture_setup(fixturedef):
    """
    Allow fixtures to be coroutines. Run coroutine fixtures in an event loop.

    Synchronous fixtures are left untouched.  For a coroutine function or an
    async generator function, ``fixturedef.func`` is replaced by a
    synchronous wrapper that drives the original on the loop obtained from
    the ``loop`` fixture.
    """
    fixture_func = fixturedef.func
    is_async_gen = isasyncgenfunction(fixture_func)
    if not is_async_gen and not asyncio.iscoroutinefunction(fixture_func):
        # not an async fixture, nothing to do
        return

    # The wrapper needs 'request' to look up the 'loop' fixture.  If the
    # fixture itself did not ask for it, request it on its behalf and
    # remember to hide it again before calling the fixture.
    strip_request = 'request' not in fixturedef.argnames
    if strip_request:
        fixturedef.argnames += ('request',)

    def wrapper(*args, **kwargs):
        request = kwargs['request']
        if strip_request:
            del kwargs['request']

        # if neither the fixture nor the test use the 'loop' fixture,
        # 'getfixturevalue' will fail because the test is not parameterized
        # (this can be removed someday if 'loop' is no longer parameterized)
        if 'loop' not in request.fixturenames:
            raise Exception(
                "Asynchronous fixtures must depend on the 'loop' fixture or "
                "be used in tests depending from it."
            )

        event_loop = request.getfixturevalue('loop')

        if not is_async_gen:
            # regular async fixture: its result is the fixture value
            return event_loop.run_until_complete(
                fixture_func(*args, **kwargs))

        # async generator fixture: advance it once now for the fixture
        # value, and once more from a finalizer to run its teardown part
        agen = fixture_func(*args, **kwargs)

        def finalizer():
            # the generator finishing is the expected outcome here
            with contextlib.suppress(StopAsyncIteration):
                return event_loop.run_until_complete(agen.__anext__())

        request.addfinalizer(finalizer)
        return event_loop.run_until_complete(agen.__anext__())

    fixturedef.func = wrapper
@pytest.fixture
def fast(request):
    """Value of the --aiohttp-fast config option."""
    config = request.config
    return config.getoption('--aiohttp-fast')
@pytest.fixture
def loop_debug(request):
    """Value of the --aiohttp-enable-loop-debug config option."""
    config = request.config
    return config.getoption('--aiohttp-enable-loop-debug')
- @contextlib.contextmanager
- def _runtime_warning_context():
- """
- Context manager which checks for RuntimeWarnings, specifically to
- avoid "coroutine 'X' was never awaited" warnings being missed.
- If RuntimeWarnings occur in the context a RuntimeError is raised.
- """
- with warnings.catch_warnings(record=True) as _warnings:
- yield
- rw = ['{w.filename}:{w.lineno}:{w.message}'.format(w=w)
- for w in _warnings if w.category == RuntimeWarning]
- if rw:
- raise RuntimeError('{} Runtime Warning{},\n{}'.format(
- len(rw),
- '' if len(rw) == 1 else 's',
- '\n'.join(rw)
- ))
- @contextlib.contextmanager
- def _passthrough_loop_context(loop, fast=False):
- """
- setups and tears down a loop unless one is passed in via the loop
- argument when it's passed straight through.
- """
- if loop:
- # loop already exists, pass it straight through
- yield loop
- else:
- # this shadows loop_context's standard behavior
- loop = setup_test_loop()
- yield loop
- teardown_test_loop(loop, fast=fast)
def pytest_pycollect_makeitem(collector, name, obj):
    """
    Fix pytest collecting for coroutines.

    Returns the list produced by ``collector._genfunctions`` when ``name``
    passes the collector's function-name filter and ``obj`` is a coroutine
    function; returns None otherwise.
    """
    if not collector.funcnamefilter(name):
        return None
    if not asyncio.iscoroutinefunction(obj):
        return None
    return list(collector._genfunctions(name, obj))
def pytest_pyfunc_call(pyfuncitem):
    """
    Run coroutines in an event loop instead of a normal function call.

    Returns True after running a coroutine test function, and None for
    any other kind of test function.
    """
    fast = pyfuncitem.config.getoption("--aiohttp-fast")
    if not asyncio.iscoroutinefunction(pyfuncitem.function):
        return None

    funcargs = pyfuncitem.funcargs
    # reuse the loop from the 'loop' fixture when the test requested it
    existing_loop = funcargs.get('loop', None)
    with _runtime_warning_context(), \
            _passthrough_loop_context(existing_loop, fast=fast) as _loop:
        testargs = {}
        for arg in pyfuncitem._fixtureinfo.argnames:
            testargs[arg] = funcargs[arg]
        _loop.run_until_complete(pyfuncitem.obj(**testargs))
    return True
def pytest_generate_tests(metafunc):
    """Parametrize the ``loop_factory`` fixture from ``--aiohttp-loop``.

    The option value is a comma separated list of loop names; ``all`` is
    shorthand for ``pyloop,uvloop?,tokio?``.  A name with a trailing ``?``
    is optional and is skipped when that loop is not available.

    Does nothing when the test does not use ``loop_factory``.

    Raises:
        ValueError: a required loop name is not available.
    """
    if 'loop_factory' not in metafunc.fixturenames:
        return

    loops = metafunc.config.option.aiohttp_loop
    avail_factories = {'pyloop': asyncio.DefaultEventLoopPolicy}

    if uvloop is not None:  # pragma: no cover
        avail_factories['uvloop'] = uvloop.EventLoopPolicy

    if tokio is not None:  # pragma: no cover
        avail_factories['tokio'] = tokio.EventLoopPolicy

    if loops == 'all':
        loops = 'pyloop,uvloop?,tokio?'

    factories = {}
    for name in loops.split(','):
        # strip whitespace before looking for the optional marker, so
        # that "uvloop? " is still treated as optional
        name = name.strip()
        required = not name.endswith('?')
        name = name.strip(' ?')
        if name not in avail_factories:  # pragma: no cover
            if required:
                # report the loops that can actually be chosen, not the
                # ones selected so far
                raise ValueError(
                    "Unknown loop '%s', available loops: %s" % (
                        name, list(avail_factories.keys())))
            else:
                continue
        factories[name] = avail_factories[name]
    metafunc.parametrize("loop_factory",
                         list(factories.values()),
                         ids=list(factories.keys()))
@pytest.fixture
def loop(loop_factory, fast, loop_debug):
    """Yield an event loop created under the policy from loop_factory.

    The policy returned by ``loop_factory()`` is installed as the asyncio
    event loop policy, then a loop from ``loop_context`` is set as the
    current event loop and yielded.  Debug mode is switched on when
    ``loop_debug`` is true.
    """
    asyncio.set_event_loop_policy(loop_factory())
    with loop_context(fast=fast) as event_loop:
        if loop_debug:
            event_loop.set_debug(True)  # pragma: no cover
        asyncio.set_event_loop(event_loop)
        yield event_loop
@pytest.fixture
def unused_port(aiohttp_unused_port):  # pragma: no cover
    """Deprecated alias of the aiohttp_unused_port fixture."""
    message = "Deprecated, use aiohttp_unused_port fixture instead"
    warnings.warn(message, DeprecationWarning)
    return aiohttp_unused_port
@pytest.fixture
def aiohttp_unused_port():
    """Return the unused_port helper imported from test_utils.

    The fixture value is the helper itself, not a port number.
    """
    port_finder = _unused_port
    return port_finder
@pytest.fixture
def aiohttp_server(loop):
    """Factory to create a TestServer instance, given an app.

    aiohttp_server(app, **kwargs)

    ``port`` is passed to ``TestServer``; the remaining keyword arguments
    are passed to ``start_server``.  Every server created through the
    factory is closed when the fixture is finalized.
    """
    started = []

    async def go(app, *, port=None, **kwargs):
        srv = TestServer(app, port=port)
        await srv.start_server(loop=loop, **kwargs)
        # only servers that started successfully are tracked for cleanup
        started.append(srv)
        return srv

    async def finalize():
        # close in reverse order of creation
        while started:
            srv = started.pop()
            await srv.close()

    yield go

    loop.run_until_complete(finalize())
@pytest.fixture
def test_server(aiohttp_server):  # pragma: no cover
    """Deprecated alias of the aiohttp_server fixture."""
    message = "Deprecated, use aiohttp_server fixture instead"
    warnings.warn(message, DeprecationWarning)
    return aiohttp_server
@pytest.fixture
def aiohttp_raw_server(loop):
    """Factory to create a RawTestServer instance, given a web handler.

    aiohttp_raw_server(handler, **kwargs)

    ``port`` is passed to ``RawTestServer``; the remaining keyword
    arguments are passed to ``start_server``.  Every server created
    through the factory is closed when the fixture is finalized.
    """
    started = []

    async def go(handler, *, port=None, **kwargs):
        srv = RawTestServer(handler, port=port)
        await srv.start_server(loop=loop, **kwargs)
        # only servers that started successfully are tracked for cleanup
        started.append(srv)
        return srv

    async def finalize():
        # close in reverse order of creation
        while started:
            srv = started.pop()
            await srv.close()

    yield go

    loop.run_until_complete(finalize())
@pytest.fixture
def raw_test_server(aiohttp_raw_server):  # pragma: no cover
    """Deprecated alias of the aiohttp_raw_server fixture."""
    message = "Deprecated, use aiohttp_raw_server fixture instead"
    warnings.warn(message, DeprecationWarning)
    return aiohttp_raw_server
@pytest.fixture
def aiohttp_client(loop):
    """Factory to create a TestClient instance.

    aiohttp_client(app, **kwargs)
    aiohttp_client(server, **kwargs)
    aiohttp_client(raw_server, **kwargs)

    The first argument may also be a callable that is neither an
    Application nor a test server; it is then called as
    ``factory(loop, *args, **kwargs)`` and its result is used in its place.
    Every client created through the factory is closed when the fixture is
    finalized.
    """
    created = []

    async def go(__param, *args, server_kwargs=None, **kwargs):
        is_factory = (
            isinstance(__param, Callable) and
            not isinstance(__param, (Application, BaseTestServer))
        )
        if is_factory:
            # the factory consumes the arguments; none are left for
            # the TestClient constructor
            __param = __param(loop, *args, **kwargs)
            kwargs = {}
        else:
            assert not args, "args should be empty"

        if isinstance(__param, Application):
            # wrap a bare application in a server of its own
            target = TestServer(__param, loop=loop, **(server_kwargs or {}))
        elif isinstance(__param, BaseTestServer):
            target = __param
        else:
            raise ValueError("Unknown argument type: %r" % type(__param))

        client = TestClient(target, loop=loop, **kwargs)
        await client.start_server()
        created.append(client)
        return client

    async def finalize():
        # close in reverse order of creation
        while created:
            client = created.pop()
            await client.close()

    yield go

    loop.run_until_complete(finalize())
@pytest.fixture
def test_client(aiohttp_client):  # pragma: no cover
    """Deprecated alias of the aiohttp_client fixture."""
    message = "Deprecated, use aiohttp_client fixture instead"
    warnings.warn(message, DeprecationWarning)
    return aiohttp_client