You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

112 lines
3.5 KiB

4 years ago
  1. import asyncio
  2. import socket
  3. from typing import Any, Dict, List, Optional
  4. from .abc import AbstractResolver
  5. from .helpers import get_running_loop
  6. __all__ = ('ThreadedResolver', 'AsyncResolver', 'DefaultResolver')
  7. try:
  8. import aiodns
  9. # aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
  10. except ImportError: # pragma: no cover
  11. aiodns = None
  12. aiodns_default = False
  13. class ThreadedResolver(AbstractResolver):
  14. """Use Executor for synchronous getaddrinfo() calls, which defaults to
  15. concurrent.futures.ThreadPoolExecutor.
  16. """
  17. def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
  18. self._loop = get_running_loop(loop)
  19. async def resolve(self, host: str, port: int=0,
  20. family: int=socket.AF_INET) -> List[Dict[str, Any]]:
  21. infos = await self._loop.getaddrinfo(
  22. host, port, type=socket.SOCK_STREAM, family=family)
  23. hosts = []
  24. for family, _, proto, _, address in infos:
  25. hosts.append(
  26. {'hostname': host,
  27. 'host': address[0], 'port': address[1],
  28. 'family': family, 'proto': proto,
  29. 'flags': socket.AI_NUMERICHOST})
  30. return hosts
  31. async def close(self) -> None:
  32. pass
  33. class AsyncResolver(AbstractResolver):
  34. """Use the `aiodns` package to make asynchronous DNS lookups"""
  35. def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None,
  36. *args: Any, **kwargs: Any) -> None:
  37. if aiodns is None:
  38. raise RuntimeError("Resolver requires aiodns library")
  39. self._loop = get_running_loop(loop)
  40. self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs)
  41. if not hasattr(self._resolver, 'gethostbyname'):
  42. # aiodns 1.1 is not available, fallback to DNSResolver.query
  43. self.resolve = self._resolve_with_query # type: ignore
  44. async def resolve(self, host: str, port: int=0,
  45. family: int=socket.AF_INET) -> List[Dict[str, Any]]:
  46. try:
  47. resp = await self._resolver.gethostbyname(host, family)
  48. except aiodns.error.DNSError as exc:
  49. msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
  50. raise OSError(msg) from exc
  51. hosts = []
  52. for address in resp.addresses:
  53. hosts.append(
  54. {'hostname': host,
  55. 'host': address, 'port': port,
  56. 'family': family, 'proto': 0,
  57. 'flags': socket.AI_NUMERICHOST})
  58. if not hosts:
  59. raise OSError("DNS lookup failed")
  60. return hosts
  61. async def _resolve_with_query(
  62. self, host: str, port: int=0,
  63. family: int=socket.AF_INET) -> List[Dict[str, Any]]:
  64. if family == socket.AF_INET6:
  65. qtype = 'AAAA'
  66. else:
  67. qtype = 'A'
  68. try:
  69. resp = await self._resolver.query(host, qtype)
  70. except aiodns.error.DNSError as exc:
  71. msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
  72. raise OSError(msg) from exc
  73. hosts = []
  74. for rr in resp:
  75. hosts.append(
  76. {'hostname': host,
  77. 'host': rr.host, 'port': port,
  78. 'family': family, 'proto': 0,
  79. 'flags': socket.AI_NUMERICHOST})
  80. if not hosts:
  81. raise OSError("DNS lookup failed")
  82. return hosts
  83. async def close(self) -> None:
  84. return self._resolver.cancel()
# Resolver class exported as the default. It is AsyncResolver only when
# aiodns_default is true; the module-level assignment above sets it to False,
# so this currently evaluates to ThreadedResolver.
DefaultResolver = AsyncResolver if aiodns_default else ThreadedResolver