|
|
- # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-
- # Copyright (C) 2003-2017 Nominum, Inc.
- #
- # Permission to use, copy, modify, and distribute this software and its
- # documentation for any purpose with or without fee is hereby granted,
- # provided that the above copyright notice and this permission notice
- # appear in all copies.
- #
- # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
- # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
- # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
- # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
- # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
- # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
- # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-
- """Talk to a DNS server."""
-
- from __future__ import generators
-
- import errno
- import select
- import socket
- import struct
- import sys
- import time
-
- import dns.exception
- import dns.inet
- import dns.name
- import dns.message
- import dns.rcode
- import dns.rdataclass
- import dns.rdatatype
- from ._compat import long, string_types, PY3
-
- if PY3:
- select_error = OSError
- else:
- select_error = select.error
-
- # Function used to create a socket. Can be overridden if needed in special
- # situations.
- socket_factory = socket.socket
-
- class UnexpectedSource(dns.exception.DNSException):
- """A DNS query response came from an unexpected address or port."""
-
-
- class BadResponse(dns.exception.FormError):
- """A DNS query response does not respond to the question asked."""
-
-
- class TransferError(dns.exception.DNSException):
- """A zone transfer response got a non-zero rcode."""
-
- def __init__(self, rcode):
- message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode)
- super(TransferError, self).__init__(message)
- self.rcode = rcode
-
-
- def _compute_expiration(timeout):
- if timeout is None:
- return None
- else:
- return time.time() + timeout
-
- # This module can use either poll() or select() as the "polling backend".
- #
- # A backend function takes an fd, bools for readability, writablity, and
- # error detection, and a timeout.
-
- def _poll_for(fd, readable, writable, error, timeout):
- """Poll polling backend."""
-
- event_mask = 0
- if readable:
- event_mask |= select.POLLIN
- if writable:
- event_mask |= select.POLLOUT
- if error:
- event_mask |= select.POLLERR
-
- pollable = select.poll()
- pollable.register(fd, event_mask)
-
- if timeout:
- event_list = pollable.poll(long(timeout * 1000))
- else:
- event_list = pollable.poll()
-
- return bool(event_list)
-
-
- def _select_for(fd, readable, writable, error, timeout):
- """Select polling backend."""
-
- rset, wset, xset = [], [], []
-
- if readable:
- rset = [fd]
- if writable:
- wset = [fd]
- if error:
- xset = [fd]
-
- if timeout is None:
- (rcount, wcount, xcount) = select.select(rset, wset, xset)
- else:
- (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout)
-
- return bool((rcount or wcount or xcount))
-
-
- def _wait_for(fd, readable, writable, error, expiration):
- # Use the selected polling backend to wait for any of the specified
- # events. An "expiration" absolute time is converted into a relative
- # timeout.
-
- done = False
- while not done:
- if expiration is None:
- timeout = None
- else:
- timeout = expiration - time.time()
- if timeout <= 0.0:
- raise dns.exception.Timeout
- try:
- if not _polling_backend(fd, readable, writable, error, timeout):
- raise dns.exception.Timeout
- except select_error as e:
- if e.args[0] != errno.EINTR:
- raise e
- done = True
-
-
- def _set_polling_backend(fn):
- # Internal API. Do not use.
-
- global _polling_backend
-
- _polling_backend = fn
-
- if hasattr(select, 'poll'):
- # Prefer poll() on platforms that support it because it has no
- # limits on the maximum value of a file descriptor (plus it will
- # be more efficient for high values).
- _polling_backend = _poll_for
- else:
- _polling_backend = _select_for
-
-
- def _wait_for_readable(s, expiration):
- _wait_for(s, True, False, True, expiration)
-
-
- def _wait_for_writable(s, expiration):
- _wait_for(s, False, True, True, expiration)
-
-
- def _addresses_equal(af, a1, a2):
- # Convert the first value of the tuple, which is a textual format
- # address into binary form, so that we are not confused by different
- # textual representations of the same address
- try:
- n1 = dns.inet.inet_pton(af, a1[0])
- n2 = dns.inet.inet_pton(af, a2[0])
- except dns.exception.SyntaxError:
- return False
- return n1 == n2 and a1[1:] == a2[1:]
-
-
- def _destination_and_source(af, where, port, source, source_port):
- # Apply defaults and compute destination and source tuples
- # suitable for use in connect(), sendto(), or bind().
- if af is None:
- try:
- af = dns.inet.af_for_address(where)
- except Exception:
- af = dns.inet.AF_INET
- if af == dns.inet.AF_INET:
- destination = (where, port)
- if source is not None or source_port != 0:
- if source is None:
- source = '0.0.0.0'
- source = (source, source_port)
- elif af == dns.inet.AF_INET6:
- destination = (where, port, 0, 0)
- if source is not None or source_port != 0:
- if source is None:
- source = '::'
- source = (source, source_port, 0, 0)
- return (af, destination, source)
-
-
- def send_udp(sock, what, destination, expiration=None):
- """Send a DNS message to the specified UDP socket.
-
- *sock*, a ``socket``.
-
- *what*, a ``binary`` or ``dns.message.Message``, the message to send.
-
- *destination*, a destination tuple appropriate for the address family
- of the socket, specifying where to send the query.
-
- *expiration*, a ``float`` or ``None``, the absolute time at which
- a timeout exception should be raised. If ``None``, no timeout will
- occur.
-
- Returns an ``(int, float)`` tuple of bytes sent and the sent time.
- """
-
- if isinstance(what, dns.message.Message):
- what = what.to_wire()
- _wait_for_writable(sock, expiration)
- sent_time = time.time()
- n = sock.sendto(what, destination)
- return (n, sent_time)
-
-
- def receive_udp(sock, destination, expiration=None,
- ignore_unexpected=False, one_rr_per_rrset=False,
- keyring=None, request_mac=b'', ignore_trailing=False):
- """Read a DNS message from a UDP socket.
-
- *sock*, a ``socket``.
-
- *destination*, a destination tuple appropriate for the address family
- of the socket, specifying where the associated query was sent.
-
- *expiration*, a ``float`` or ``None``, the absolute time at which
- a timeout exception should be raised. If ``None``, no timeout will
- occur.
-
- *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
- unexpected sources.
-
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
- RRset.
-
- *keyring*, a ``dict``, the keyring to use for TSIG.
-
- *request_mac*, a ``binary``, the MAC of the request (for TSIG).
-
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the received message.
-
- Raises if the message is malformed, if network errors occur, of if
- there is a timeout.
-
- Returns a ``dns.message.Message`` object.
- """
-
- wire = b''
- while 1:
- _wait_for_readable(sock, expiration)
- (wire, from_address) = sock.recvfrom(65535)
- if _addresses_equal(sock.family, from_address, destination) or \
- (dns.inet.is_multicast(destination[0]) and
- from_address[1:] == destination[1:]):
- break
- if not ignore_unexpected:
- raise UnexpectedSource('got a response from '
- '%s instead of %s' % (from_address,
- destination))
- received_time = time.time()
- r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
- one_rr_per_rrset=one_rr_per_rrset,
- ignore_trailing=ignore_trailing)
- return (r, received_time)
-
- def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
- ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False):
- """Return the response obtained after sending a query via UDP.
-
- *q*, a ``dns.message.Message``, the query to send
-
- *where*, a ``text`` containing an IPv4 or IPv6 address, where
- to send the message.
-
- *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
- query times out. If ``None``, the default, wait forever.
-
- *port*, an ``int``, the port send the message to. The default is 53.
-
- *af*, an ``int``, the address family to use. The default is ``None``,
- which causes the address family to use to be inferred from the form of
- *where*. If the inference attempt fails, AF_INET is used. This
- parameter is historical; you need never set it.
-
- *source*, a ``text`` containing an IPv4 or IPv6 address, specifying
- the source address. The default is the wildcard address.
-
- *source_port*, an ``int``, the port from which to send the message.
- The default is 0.
-
- *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
- unexpected sources.
-
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
- RRset.
-
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the received message.
-
- Returns a ``dns.message.Message``.
- """
-
- wire = q.to_wire()
- (af, destination, source) = _destination_and_source(af, where, port,
- source, source_port)
- s = socket_factory(af, socket.SOCK_DGRAM, 0)
- received_time = None
- sent_time = None
- try:
- expiration = _compute_expiration(timeout)
- s.setblocking(0)
- if source is not None:
- s.bind(source)
- (_, sent_time) = send_udp(s, wire, destination, expiration)
- (r, received_time) = receive_udp(s, destination, expiration,
- ignore_unexpected, one_rr_per_rrset,
- q.keyring, q.mac, ignore_trailing)
- finally:
- if sent_time is None or received_time is None:
- response_time = 0
- else:
- response_time = received_time - sent_time
- s.close()
- r.time = response_time
- if not q.is_response(r):
- raise BadResponse
- return r
-
-
- def _net_read(sock, count, expiration):
- """Read the specified number of bytes from sock. Keep trying until we
- either get the desired amount, or we hit EOF.
- A Timeout exception will be raised if the operation is not completed
- by the expiration time.
- """
- s = b''
- while count > 0:
- _wait_for_readable(sock, expiration)
- n = sock.recv(count)
- if n == b'':
- raise EOFError
- count = count - len(n)
- s = s + n
- return s
-
-
- def _net_write(sock, data, expiration):
- """Write the specified data to the socket.
- A Timeout exception will be raised if the operation is not completed
- by the expiration time.
- """
- current = 0
- l = len(data)
- while current < l:
- _wait_for_writable(sock, expiration)
- current += sock.send(data[current:])
-
-
- def send_tcp(sock, what, expiration=None):
- """Send a DNS message to the specified TCP socket.
-
- *sock*, a ``socket``.
-
- *what*, a ``binary`` or ``dns.message.Message``, the message to send.
-
- *expiration*, a ``float`` or ``None``, the absolute time at which
- a timeout exception should be raised. If ``None``, no timeout will
- occur.
-
- Returns an ``(int, float)`` tuple of bytes sent and the sent time.
- """
-
- if isinstance(what, dns.message.Message):
- what = what.to_wire()
- l = len(what)
- # copying the wire into tcpmsg is inefficient, but lets us
- # avoid writev() or doing a short write that would get pushed
- # onto the net
- tcpmsg = struct.pack("!H", l) + what
- _wait_for_writable(sock, expiration)
- sent_time = time.time()
- _net_write(sock, tcpmsg, expiration)
- return (len(tcpmsg), sent_time)
-
- def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
- keyring=None, request_mac=b'', ignore_trailing=False):
- """Read a DNS message from a TCP socket.
-
- *sock*, a ``socket``.
-
- *expiration*, a ``float`` or ``None``, the absolute time at which
- a timeout exception should be raised. If ``None``, no timeout will
- occur.
-
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
- RRset.
-
- *keyring*, a ``dict``, the keyring to use for TSIG.
-
- *request_mac*, a ``binary``, the MAC of the request (for TSIG).
-
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the received message.
-
- Raises if the message is malformed, if network errors occur, of if
- there is a timeout.
-
- Returns a ``dns.message.Message`` object.
- """
-
- ldata = _net_read(sock, 2, expiration)
- (l,) = struct.unpack("!H", ldata)
- wire = _net_read(sock, l, expiration)
- received_time = time.time()
- r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
- one_rr_per_rrset=one_rr_per_rrset,
- ignore_trailing=ignore_trailing)
- return (r, received_time)
-
- def _connect(s, address):
- try:
- s.connect(address)
- except socket.error:
- (ty, v) = sys.exc_info()[:2]
-
- if hasattr(v, 'errno'):
- v_err = v.errno
- else:
- v_err = v[0]
- if v_err not in [errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY]:
- raise v
-
-
- def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False):
- """Return the response obtained after sending a query via TCP.
-
- *q*, a ``dns.message.Message``, the query to send
-
- *where*, a ``text`` containing an IPv4 or IPv6 address, where
- to send the message.
-
- *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
- query times out. If ``None``, the default, wait forever.
-
- *port*, an ``int``, the port send the message to. The default is 53.
-
- *af*, an ``int``, the address family to use. The default is ``None``,
- which causes the address family to use to be inferred from the form of
- *where*. If the inference attempt fails, AF_INET is used. This
- parameter is historical; you need never set it.
-
- *source*, a ``text`` containing an IPv4 or IPv6 address, specifying
- the source address. The default is the wildcard address.
-
- *source_port*, an ``int``, the port from which to send the message.
- The default is 0.
-
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
- RRset.
-
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the received message.
-
- Returns a ``dns.message.Message``.
- """
-
- wire = q.to_wire()
- (af, destination, source) = _destination_and_source(af, where, port,
- source, source_port)
- s = socket_factory(af, socket.SOCK_STREAM, 0)
- begin_time = None
- received_time = None
- try:
- expiration = _compute_expiration(timeout)
- s.setblocking(0)
- begin_time = time.time()
- if source is not None:
- s.bind(source)
- _connect(s, destination)
- send_tcp(s, wire, expiration)
- (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
- q.keyring, q.mac, ignore_trailing)
- finally:
- if begin_time is None or received_time is None:
- response_time = 0
- else:
- response_time = received_time - begin_time
- s.close()
- r.time = response_time
- if not q.is_response(r):
- raise BadResponse
- return r
-
-
- def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
- timeout=None, port=53, keyring=None, keyname=None, relativize=True,
- af=None, lifetime=None, source=None, source_port=0, serial=0,
- use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
- """Return a generator for the responses to a zone transfer.
-
- *where*. If the inference attempt fails, AF_INET is used. This
- parameter is historical; you need never set it.
-
- *zone*, a ``dns.name.Name`` or ``text``, the name of the zone to transfer.
-
- *rdtype*, an ``int`` or ``text``, the type of zone transfer. The
- default is ``dns.rdatatype.AXFR``. ``dns.rdatatype.IXFR`` can be
- used to do an incremental transfer instead.
-
- *rdclass*, an ``int`` or ``text``, the class of the zone transfer.
- The default is ``dns.rdataclass.IN``.
-
- *timeout*, a ``float``, the number of seconds to wait for each
- response message. If None, the default, wait forever.
-
- *port*, an ``int``, the port send the message to. The default is 53.
-
- *keyring*, a ``dict``, the keyring to use for TSIG.
-
- *keyname*, a ``dns.name.Name`` or ``text``, the name of the TSIG
- key to use.
-
- *relativize*, a ``bool``. If ``True``, all names in the zone will be
- relativized to the zone origin. It is essential that the
- relativize setting matches the one specified to
- ``dns.zone.from_xfr()`` if using this generator to make a zone.
-
- *af*, an ``int``, the address family to use. The default is ``None``,
- which causes the address family to use to be inferred from the form of
- *where*. If the inference attempt fails, AF_INET is used. This
- parameter is historical; you need never set it.
-
- *lifetime*, a ``float``, the total number of seconds to spend
- doing the transfer. If ``None``, the default, then there is no
- limit on the time the transfer may take.
-
- *source*, a ``text`` containing an IPv4 or IPv6 address, specifying
- the source address. The default is the wildcard address.
-
- *source_port*, an ``int``, the port from which to send the message.
- The default is 0.
-
- *serial*, an ``int``, the SOA serial number to use as the base for
- an IXFR diff sequence (only meaningful if *rdtype* is
- ``dns.rdatatype.IXFR``).
-
- *use_udp*, a ``bool``. If ``True``, use UDP (only meaningful for IXFR).
-
- *keyalgorithm*, a ``dns.name.Name`` or ``text``, the TSIG algorithm to use.
-
- Raises on errors, and so does the generator.
-
- Returns a generator of ``dns.message.Message`` objects.
- """
-
- if isinstance(zone, string_types):
- zone = dns.name.from_text(zone)
- if isinstance(rdtype, string_types):
- rdtype = dns.rdatatype.from_text(rdtype)
- q = dns.message.make_query(zone, rdtype, rdclass)
- if rdtype == dns.rdatatype.IXFR:
- rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA',
- '. . %u 0 0 0 0' % serial)
- q.authority.append(rrset)
- if keyring is not None:
- q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
- wire = q.to_wire()
- (af, destination, source) = _destination_and_source(af, where, port,
- source, source_port)
- if use_udp:
- if rdtype != dns.rdatatype.IXFR:
- raise ValueError('cannot do a UDP AXFR')
- s = socket_factory(af, socket.SOCK_DGRAM, 0)
- else:
- s = socket_factory(af, socket.SOCK_STREAM, 0)
- s.setblocking(0)
- if source is not None:
- s.bind(source)
- expiration = _compute_expiration(lifetime)
- _connect(s, destination)
- l = len(wire)
- if use_udp:
- _wait_for_writable(s, expiration)
- s.send(wire)
- else:
- tcpmsg = struct.pack("!H", l) + wire
- _net_write(s, tcpmsg, expiration)
- done = False
- delete_mode = True
- expecting_SOA = False
- soa_rrset = None
- if relativize:
- origin = zone
- oname = dns.name.empty
- else:
- origin = None
- oname = zone
- tsig_ctx = None
- first = True
- while not done:
- mexpiration = _compute_expiration(timeout)
- if mexpiration is None or mexpiration > expiration:
- mexpiration = expiration
- if use_udp:
- _wait_for_readable(s, expiration)
- (wire, from_address) = s.recvfrom(65535)
- else:
- ldata = _net_read(s, 2, mexpiration)
- (l,) = struct.unpack("!H", ldata)
- wire = _net_read(s, l, mexpiration)
- is_ixfr = (rdtype == dns.rdatatype.IXFR)
- r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
- xfr=True, origin=origin, tsig_ctx=tsig_ctx,
- multi=True, first=first,
- one_rr_per_rrset=is_ixfr)
- rcode = r.rcode()
- if rcode != dns.rcode.NOERROR:
- raise TransferError(rcode)
- tsig_ctx = r.tsig_ctx
- first = False
- answer_index = 0
- if soa_rrset is None:
- if not r.answer or r.answer[0].name != oname:
- raise dns.exception.FormError(
- "No answer or RRset not for qname")
- rrset = r.answer[0]
- if rrset.rdtype != dns.rdatatype.SOA:
- raise dns.exception.FormError("first RRset is not an SOA")
- answer_index = 1
- soa_rrset = rrset.copy()
- if rdtype == dns.rdatatype.IXFR:
- if soa_rrset[0].serial <= serial:
- #
- # We're already up-to-date.
- #
- done = True
- else:
- expecting_SOA = True
- #
- # Process SOAs in the answer section (other than the initial
- # SOA in the first message).
- #
- for rrset in r.answer[answer_index:]:
- if done:
- raise dns.exception.FormError("answers after final SOA")
- if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
- if expecting_SOA:
- if rrset[0].serial != serial:
- raise dns.exception.FormError(
- "IXFR base serial mismatch")
- expecting_SOA = False
- elif rdtype == dns.rdatatype.IXFR:
- delete_mode = not delete_mode
- #
- # If this SOA RRset is equal to the first we saw then we're
- # finished. If this is an IXFR we also check that we're seeing
- # the record in the expected part of the response.
- #
- if rrset == soa_rrset and \
- (rdtype == dns.rdatatype.AXFR or
- (rdtype == dns.rdatatype.IXFR and delete_mode)):
- done = True
- elif expecting_SOA:
- #
- # We made an IXFR request and are expecting another
- # SOA RR, but saw something else, so this must be an
- # AXFR response.
- #
- rdtype = dns.rdatatype.AXFR
- expecting_SOA = False
- if done and q.keyring and not r.had_tsig:
- raise dns.exception.FormError("missing TSIG")
- yield r
- s.close()
|