You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

111 lines
3.2 KiB

4 years ago
  1. import asyncio
  2. import socket
  3. from .abc import AbstractResolver
  4. __all__ = ('ThreadedResolver', 'AsyncResolver', 'DefaultResolver')
  5. try:
  6. import aiodns
  7. # aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
  8. except ImportError: # pragma: no cover
  9. aiodns = None
  10. aiodns_default = False
  11. class ThreadedResolver(AbstractResolver):
  12. """Use Executor for synchronous getaddrinfo() calls, which defaults to
  13. concurrent.futures.ThreadPoolExecutor.
  14. """
  15. def __init__(self, loop=None):
  16. if loop is None:
  17. loop = asyncio.get_event_loop()
  18. self._loop = loop
  19. async def resolve(self, host, port=0, family=socket.AF_INET):
  20. infos = await self._loop.getaddrinfo(
  21. host, port, type=socket.SOCK_STREAM, family=family)
  22. hosts = []
  23. for family, _, proto, _, address in infos:
  24. hosts.append(
  25. {'hostname': host,
  26. 'host': address[0], 'port': address[1],
  27. 'family': family, 'proto': proto,
  28. 'flags': socket.AI_NUMERICHOST})
  29. return hosts
  30. async def close(self):
  31. pass
  32. class AsyncResolver(AbstractResolver):
  33. """Use the `aiodns` package to make asynchronous DNS lookups"""
  34. def __init__(self, loop=None, *args, **kwargs):
  35. if loop is None:
  36. loop = asyncio.get_event_loop()
  37. if aiodns is None:
  38. raise RuntimeError("Resolver requires aiodns library")
  39. self._loop = loop
  40. self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs)
  41. if not hasattr(self._resolver, 'gethostbyname'):
  42. # aiodns 1.1 is not available, fallback to DNSResolver.query
  43. self.resolve = self._resolve_with_query
  44. async def resolve(self, host, port=0, family=socket.AF_INET):
  45. try:
  46. resp = await self._resolver.gethostbyname(host, family)
  47. except aiodns.error.DNSError as exc:
  48. msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
  49. raise OSError(msg) from exc
  50. hosts = []
  51. for address in resp.addresses:
  52. hosts.append(
  53. {'hostname': host,
  54. 'host': address, 'port': port,
  55. 'family': family, 'proto': 0,
  56. 'flags': socket.AI_NUMERICHOST})
  57. if not hosts:
  58. raise OSError("DNS lookup failed")
  59. return hosts
  60. async def _resolve_with_query(self, host, port=0, family=socket.AF_INET):
  61. if family == socket.AF_INET6:
  62. qtype = 'AAAA'
  63. else:
  64. qtype = 'A'
  65. try:
  66. resp = await self._resolver.query(host, qtype)
  67. except aiodns.error.DNSError as exc:
  68. msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
  69. raise OSError(msg) from exc
  70. hosts = []
  71. for rr in resp:
  72. hosts.append(
  73. {'hostname': host,
  74. 'host': rr.host, 'port': port,
  75. 'family': family, 'proto': 0,
  76. 'flags': socket.AI_NUMERICHOST})
  77. if not hosts:
  78. raise OSError("DNS lookup failed")
  79. return hosts
  80. async def close(self):
  81. return self._resolver.cancel()
  82. DefaultResolver = AsyncResolver if aiodns_default else ThreadedResolver