- """ Multicast DNS Service Discovery for Python, v0.14-wmcbrine
- Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
- This module provides a framework for the use of DNS Service Discovery
- using IP multicast.
- This library is free software; you can redistribute it and/or
- modify it under the terms of the GNU Lesser General Public
- License as published by the Free Software Foundation; either
- version 2.1 of the License, or (at your option) any later version.
- This library is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
- Lesser General Public License for more details.
- You should have received a copy of the GNU Lesser General Public
- License along with this library; if not, write to the Free Software
- Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
- """
- import enum
- import errno
- import logging
- import re
- import select
- import socket
- import struct
- import sys
- import threading
- import time
- from functools import reduce
- from typing import Callable # noqa # used in type hints
- from typing import Dict, List, Optional, Union
- import ifaddr
# Package metadata.
__author__ = 'Paul Scott-Murphy, William McBrine'
__maintainer__ = 'Jakub Stasiak <jakub@stasiak.at>'
__version__ = '0.21.3'
__license__ = 'LGPL'

# Names exported by a star-import of this module.
__all__ = [
    "__version__",
    "Zeroconf", "ServiceInfo", "ServiceBrowser",
    "Error", "InterfaceChoice", "ServiceStateChange",
]
# Refuse to import on interpreters older than 3.4.  ``sys.version_info`` is a
# five-element tuple, so comparing it with ``<= (3, 3)`` lets every 3.3.x
# release through (``(3, 3, 0, ...)`` sorts *after* ``(3, 3)``).  Comparing
# against ``(3, 4)`` matches what the error message promises.
if sys.version_info < (3, 4):
    raise ImportError('''
Python version > 3.3 required for python-zeroconf.
If you need support for Python 2 or Python 3.3 please use version 19.1
''')

# Module-level logger.  The NullHandler keeps the library silent unless the
# application configures logging; the level is only defaulted when nobody has
# set one explicitly.
log = logging.getLogger(__name__)
log.addHandler(logging.NullHandler())
if log.level == logging.NOTSET:
    # logging.WARNING is the documented name; logging.WARN is a legacy alias.
    log.setLevel(logging.WARNING)
# Some timing constants
_CHECK_TIME = 175
# Default ``delay`` of ServiceBrowser.  The name is referenced by the code
# below but was not defined in this section.
# NOTE(review): 500 is the value used by upstream python-zeroconf 0.21.x -- confirm.
_BROWSER_TIME = 500

# Some DNS constants
# IPv4 link-local multicast group used by Multicast DNS (RFC 6762, section 3).
# An empty string here would make every multicast send/compare meaningless.
_MDNS_ADDR = '224.0.0.251'
_MDNS_PORT = 5353
_DNS_PORT = 53
_DNS_TTL = 120  # two minutes default TTL as recommended by RFC6762

_MAX_MSG_TYPICAL = 1460  # unused
# Upper bound used for receive buffers and outgoing packet size.  Referenced
# by the code below but was not defined in this section.
# NOTE(review): 8966 is the value used by upstream python-zeroconf 0.21.x -- confirm.
_MAX_MSG_ABSOLUTE = 8966

# DNS header flag bits (RFC 1035, section 4.1.1)
_FLAGS_QR_MASK = 0x8000  # query response mask
_FLAGS_QR_QUERY = 0x0000  # query
_FLAGS_QR_RESPONSE = 0x8000  # response

_FLAGS_AA = 0x0400  # Authoritative answer
_FLAGS_TC = 0x0200  # Truncated
_FLAGS_RD = 0x0100  # Recursion desired
_FLAGS_RA = 0x0080  # Recursion available (was 0x8000, which is the QR bit)

_FLAGS_Z = 0x0040  # Zero
_FLAGS_AD = 0x0020  # Authentic data
_FLAGS_CD = 0x0010  # Checking disabled

# DNS classes
_CLASS_IN = 1
_CLASS_CS = 2
_CLASS_CH = 3
_CLASS_HS = 4
_CLASS_NONE = 254
_CLASS_ANY = 255
# Everything except the top bit is the class proper; the top bit is the
# "unique" (cache-flush) marker tested by DNSEntry.
_CLASS_MASK = 0x7FFF
_CLASS_UNIQUE = 0x8000

# DNS record types (RFC 1035, section 3.2.2, plus AAAA and SRV)
_TYPE_A = 1
_TYPE_NS = 2
_TYPE_MD = 3
_TYPE_MF = 4
_TYPE_CNAME = 5
_TYPE_SOA = 6
_TYPE_MB = 7
_TYPE_MG = 8
_TYPE_MR = 9
_TYPE_NULL = 10
_TYPE_WKS = 11
_TYPE_PTR = 12
_TYPE_HINFO = 13
_TYPE_MINFO = 14
_TYPE_MX = 15
_TYPE_TXT = 16
_TYPE_AAAA = 28
_TYPE_SRV = 33
_TYPE_ANY = 255

# Mapping constants to names
_CLASSES = {
    _CLASS_IN: "in",
    _CLASS_CS: "cs",
    _CLASS_CH: "ch",
    _CLASS_HS: "hs",
    _CLASS_NONE: "none",
    _CLASS_ANY: "any",
}

_TYPES = {
    _TYPE_A: "a",
    _TYPE_NS: "ns",
    _TYPE_MD: "md",
    _TYPE_MF: "mf",
    _TYPE_CNAME: "cname",
    _TYPE_SOA: "soa",
    _TYPE_MB: "mb",
    _TYPE_MG: "mg",
    _TYPE_MR: "mr",
    _TYPE_NULL: "null",
    _TYPE_WKS: "wks",
    _TYPE_PTR: "ptr",
    _TYPE_HINFO: "hinfo",
    _TYPE_MINFO: "minfo",
    _TYPE_MX: "mx",
    _TYPE_TXT: "txt",
    _TYPE_AAAA: "quada",
    _TYPE_SRV: "srv",
    _TYPE_ANY: "any",
}

# Pre-compiled patterns used when validating service names
_HAS_A_TO_Z = re.compile(r'[A-Za-z]')
_HAS_ONLY_A_TO_Z_NUM_HYPHEN = re.compile(r'^[A-Za-z0-9\-]+$')
_HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$')
_HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]')

# Packs an int in range(256) into a single byte
int2byte = struct.Struct(">B").pack
@enum.unique
class InterfaceChoice(enum.Enum):
    """Symbolic choices for network interface selection."""
    # NOTE(review): presumably "Default" means the system default interface
    # and "All" every interface; the consuming code is outside this section.
    Default = 1
    All = 2
@enum.unique
class ServiceStateChange(enum.Enum):
    """Kinds of service state change: a service was added or removed."""
    Added = 1
    Removed = 2
- # utility functions
def current_time_millis() -> float:
    """Return ``time.time()`` converted from seconds to milliseconds."""
    seconds = time.time()
    return seconds * 1000
def service_type_name(type_, *, allow_underscores: bool = False):
    """
    Validate a fully qualified service name, instance or subtype. [rfc6763]

    Returns fully qualified service name.

    Domain names used by mDNS-SD take the following forms:

                   <sn> . <_tcp|_udp> . local.
      <Instance> . <sn> . <_tcp|_udp> . local.
      <sub>._sub . <sn> . <_tcp|_udp> . local.

    1) must end with 'local.'

      This is true because we are implementing mDNS and since the 'm' means
      multi-cast, the 'local.' domain is mandatory.

    2) local is preceded with either '_udp.' or '_tcp.'

    3) service name <sn> precedes <_tcp|_udp>

      The rules for Service Names [RFC6335] state that they may be no more
      than fifteen characters long (not counting the mandatory underscore),
      consisting of only letters, digits, and hyphens, must begin and end
      with a letter or digit, must not contain consecutive hyphens, and
      must contain at least one letter.

    The instance name <Instance> and sub type <sub> may be up to 63 bytes.

    The <Instance> portion of the Service Instance Name is a user-
    friendly name consisting of arbitrary Net-Unicode text [RFC5198]. It
    MUST NOT contain ASCII control characters (byte values 0x00-0x1F and
    0x7F) [RFC20] but otherwise is allowed to contain any characters,
    without restriction, including spaces, uppercase, lowercase,
    punctuation -- including dots -- accented characters, non-Roman text,
    and anything else that may be represented using Net-Unicode.

    :param type_: Type, SubType or service name to validate
    :param allow_underscores: also accept '_' inside the service name
    :return: fully qualified service name (eg: _http._tcp.local.)
    :raises BadTypeInNameException: if any of the rules above is violated
    """
    # Both accepted suffixes have the same length, so one length serves for
    # slicing regardless of which protocol was used.
    suffix_length = len('._tcp.local.')

    if not type_.endswith(('._tcp.local.', '._udp.local.')):
        raise BadTypeInNameException(
            "Type '%s' must end with '._tcp.local.' or '._udp.local.'" %
            type_)

    remaining = type_[:-suffix_length].split('.')
    name = remaining.pop()
    if not name:
        raise BadTypeInNameException("No Service name found")

    if len(remaining) == 1 and len(remaining[0]) == 0:
        raise BadTypeInNameException(
            "Type '%s' must not start with '.'" % type_)

    if name[0] != '_':
        raise BadTypeInNameException(
            "Service name (%s) must start with '_'" % name)

    # remove leading underscore
    name = name[1:]

    # A lone '_' leaves nothing to validate; without this guard the
    # name[0] / name[-1] checks below would raise IndexError.
    if not name:
        raise BadTypeInNameException("No Service name found")

    if len(name) > 15:
        raise BadTypeInNameException(
            "Service name (%s) must be <= 15 bytes" % name)

    if '--' in name:
        raise BadTypeInNameException(
            "Service name (%s) must not contain '--'" % name)

    if '-' in (name[0], name[-1]):
        raise BadTypeInNameException(
            "Service name (%s) may not start or end with '-'" % name)

    if not _HAS_A_TO_Z.search(name):
        raise BadTypeInNameException(
            "Service name (%s) must contain at least one letter (eg: 'A-Z')" %
            name)

    allowed_characters_re = (
        _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE if allow_underscores
        else _HAS_ONLY_A_TO_Z_NUM_HYPHEN
    )

    if not allowed_characters_re.search(name):
        raise BadTypeInNameException(
            "Service name (%s) must contain only these characters: "
            "A-Z, a-z, 0-9, hyphen ('-')%s" % (name, ", underscore ('_')" if allow_underscores else ""))

    if remaining and remaining[-1] == '_sub':
        remaining.pop()
        if len(remaining) == 0 or len(remaining[0]) == 0:
            raise BadTypeInNameException(
                "_sub requires a subtype name")

    # An instance name may itself contain dots; treat whatever is left as a
    # single label.
    if len(remaining) > 1:
        remaining = ['.'.join(remaining)]

    if remaining:
        length = len(remaining[0].encode('utf-8'))
        if length > 63:
            raise BadTypeInNameException("Too long: '%s'" % remaining[0])

        if _HAS_ASCII_CONTROL_CHARS.search(remaining[0]):
            raise BadTypeInNameException(
                "Ascii control character 0x00-0x1F and 0x7F illegal in '%s'" %
                remaining[0])

    return '_' + name + type_[-suffix_length:]
- # Exceptions
class Error(Exception):
    """Base class of every exception defined by this module."""


class IncomingDecodeError(Error):
    """Error subclass; raised by DNSIncoming.read_name for bad domain names."""


class NonUniqueNameException(Error):
    """Error subclass named for non-unique name conditions."""


class NamePartTooLongException(Error):
    """Error subclass; raised by DNSOutgoing when a string exceeds its limit."""


class AbstractMethodException(Error):
    """Error subclass; raised by methods that subclasses must override."""


class BadTypeInNameException(Error):
    """Error subclass; raised when a service type name fails validation."""
- # implementation classes
class QuietLogger:
    """Mixin that logs a given message at WARNING level only the first time.

    Repeated occurrences (same exception text or same message) are demoted
    to DEBUG.  The bookkeeping lives in a class attribute, so it is shared by
    every subclass and instance.
    """

    _seen_logs = {}  # type: Dict[str, tuple]

    @classmethod
    def log_exception_warning(cls, logger_data=None):
        """Log the exception currently being handled.

        :param logger_data: optional ``(format, *args)`` tuple logged just
            before the exception itself
        """
        exc_info = sys.exc_info()
        key = str(exc_info[1])
        first_occurrence = key not in cls._seen_logs
        if first_occurrence:
            # log at warning level the first time this is seen
            cls._seen_logs[key] = exc_info
        emit = log.warning if first_occurrence else log.debug
        if logger_data is not None:
            emit(*logger_data)
        emit('Exception occurred:', exc_info=exc_info)

    @classmethod
    def log_warning_once(cls, *args):
        """Log ``args`` (format string first), counting repeats by format."""
        key = args[0]
        if key in cls._seen_logs:
            emit = log.debug
        else:
            cls._seen_logs[key] = 0
            emit = log.warning
        cls._seen_logs[key] += 1
        emit(*args)
class DNSEntry:
    """A DNS entry: a name together with a record type and class."""

    def __init__(self, name, type_, class_):
        # ``key`` is the lower-cased name (used by DNSCache as its dict key).
        self.key = name.lower()
        self.name = name
        self.type = type_
        # The top bit of the wire class is the "unique" flag; the rest is the
        # class proper.
        self.class_ = class_ & _CLASS_MASK
        self.unique = (class_ & _CLASS_UNIQUE) != 0

    def __eq__(self, other):
        """Equality test on name, type, and class"""
        if not isinstance(other, DNSEntry):
            return False
        if self.name != other.name:
            return False
        if self.type != other.type:
            return False
        return self.class_ == other.class_

    def __ne__(self, other):
        """Non-equality test"""
        return not self.__eq__(other)

    @staticmethod
    def get_class_(class_):
        """Return the symbolic name of a class, or ``?(n)`` if unknown."""
        return _CLASSES.get(class_, "?(%s)" % class_)

    @staticmethod
    def get_type(t):
        """Return the symbolic name of a record type, or ``?(n)`` if unknown."""
        return _TYPES.get(t, "?(%s)" % t)

    def to_string(self, hdr, other):
        """Render ``hdr[type,class[-unique],name[,other]]``."""
        pieces = ["%s[%s,%s" % (hdr, self.get_type(self.type),
                                self.get_class_(self.class_))]
        pieces.append("-unique," if self.unique else ",")
        pieces.append(self.name)
        pieces.append("]" if other is None else ",%s]" % other)
        return "".join(pieces)
class DNSQuestion(DNSEntry):
    """A DNS question entry"""

    def __init__(self, name: str, type_: int, class_: int) -> None:
        super().__init__(name, type_, class_)

    def answered_by(self, rec: 'DNSRecord') -> bool:
        """Return True when ``rec`` has this question's class and name and
        either the same type or this question asks for any type."""
        if self.class_ != rec.class_:
            return False
        if self.type != rec.type and self.type != _TYPE_ANY:
            return False
        return self.name == rec.name

    def __repr__(self) -> str:
        """String representation"""
        return DNSEntry.to_string(self, "question", None)
class DNSRecord(DNSEntry):
    """A DNS record - like a DNS entry, but has a TTL"""

    def __init__(self, name, type_, class_, ttl):
        DNSEntry.__init__(self, name, type_, class_)
        self.ttl = ttl  # seconds (see get_expiration_time)
        self.created = current_time_millis()

    def __eq__(self, other):
        """Abstract method; concrete record classes compare their payload."""
        raise AbstractMethodException

    def __ne__(self, other):
        """Non-equality test"""
        return not self.__eq__(other)

    def suppressed_by(self, msg):
        """Returns true if any answer in a message can suffice for the
        information held in this record."""
        return any(self.suppressed_by_answer(record) for record in msg.answers)

    def suppressed_by_answer(self, other):
        """Returns true if another record has same name, type and class,
        and if its TTL is at least half of this record's."""
        # ">=" rather than ">": RFC 6762 section 7.1 suppresses the answer
        # when the known answer's TTL is *at least* half the correct value.
        return self == other and other.ttl >= (self.ttl / 2)

    def get_expiration_time(self, percent):
        """Returns the time (in milliseconds) at which this record will have
        expired by a certain percentage."""
        # ttl [s] * 1000 [ms/s] * percent / 100  ==  percent * ttl * 10
        return self.created + (percent * self.ttl * 10)

    def get_remaining_ttl(self, now):
        """Returns the remaining TTL in seconds (never negative)."""
        return max(0, (self.get_expiration_time(100) - now) / 1000.0)

    def is_expired(self, now) -> bool:
        """Returns true if this record has expired."""
        return self.get_expiration_time(100) <= now

    def is_stale(self, now):
        """Returns true if this record is at least half way expired."""
        return self.get_expiration_time(50) <= now

    def reset_ttl(self, other):
        """Sets this record's TTL and created time to that of
        another record."""
        self.created = other.created
        self.ttl = other.ttl

    def write(self, out):
        """Abstract method; concrete record classes serialise their payload."""
        raise AbstractMethodException

    def to_string(self, other):
        """String representation with additional information"""
        arg = "%s/%s,%s" % (
            self.ttl, self.get_remaining_ttl(current_time_millis()), other)
        return DNSEntry.to_string(self, "record", arg)
class DNSAddress(DNSRecord):
    """A DNS address record (A or AAAA); ``address`` is the packed bytes."""

    def __init__(self, name, type_, class_, ttl, address):
        DNSRecord.__init__(self, name, type_, class_, ttl)
        self.address = address

    def write(self, out):
        """Used in constructing an outgoing packet"""
        out.write_string(self.address)

    def __eq__(self, other):
        """Tests equality on address"""
        return (isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and
                self.address == other.address)

    def __ne__(self, other):
        """Non-equality test"""
        return not self.__eq__(other)

    def __repr__(self):
        """Return the address in presentation form.

        4-byte addresses are rendered as dotted-quad and 16-byte addresses
        as IPv6 text; anything that cannot be converted falls back to
        ``str(self.address)``.
        """
        try:
            # inet_ntoa only understands 4-byte IPv4 addresses, so AAAA
            # payloads need inet_ntop with the IPv6 family.
            if len(self.address) == 16:
                return str(socket.inet_ntop(socket.AF_INET6, self.address))
            return str(socket.inet_ntoa(self.address))
        except (OSError, TypeError, ValueError):
            # Wrong length (OSError / ValueError) or not a bytes-like
            # object (TypeError).
            return str(self.address)
class DNSHinfo(DNSRecord):
    """A DNS host information record"""

    def __init__(self, name, type_, class_, ttl, cpu, os):
        DNSRecord.__init__(self, name, type_, class_, ttl)
        # Each field is decoded as UTF-8 when it offers ``decode``; otherwise
        # it is stored untouched.
        for attribute, raw in (('cpu', cpu), ('os', os)):
            try:
                value = raw.decode('utf-8')
            except AttributeError:
                value = raw
            setattr(self, attribute, value)

    def write(self, out):
        """Used in constructing an outgoing packet"""
        for field in (self.cpu, self.os):
            out.write_character_string(field.encode('utf-8'))

    def __eq__(self, other):
        """Tests equality on cpu and os"""
        if not isinstance(other, DNSHinfo):
            return False
        if not DNSEntry.__eq__(self, other):
            return False
        return self.cpu == other.cpu and self.os == other.os

    def __ne__(self, other):
        """Non-equality test"""
        return not self.__eq__(other)

    def __repr__(self):
        """String representation"""
        return " ".join((self.cpu, self.os))
class DNSPointer(DNSRecord):
    """A DNS pointer record"""

    def __init__(self, name, type_, class_, ttl, alias):
        DNSRecord.__init__(self, name, type_, class_, ttl)
        self.alias = alias

    def write(self, out):
        """Used in constructing an outgoing packet"""
        out.write_name(self.alias)

    def __eq__(self, other):
        """Tests equality on alias"""
        if not isinstance(other, DNSPointer):
            return False
        if not DNSEntry.__eq__(self, other):
            return False
        return self.alias == other.alias

    def __ne__(self, other):
        """Non-equality test"""
        return not self.__eq__(other)

    def __repr__(self):
        """String representation"""
        return self.to_string(self.alias)
class DNSText(DNSRecord):
    """A DNS text record; ``text`` is the raw TXT payload (bytes or None)."""

    def __init__(self, name, type_, class_, ttl, text):
        assert isinstance(text, (bytes, type(None)))
        DNSRecord.__init__(self, name, type_, class_, ttl)
        self.text = text

    def write(self, out):
        """Used in constructing an outgoing packet"""
        out.write_string(self.text)

    def __eq__(self, other):
        """Tests equality on text"""
        return (isinstance(other, DNSText) and DNSEntry.__eq__(self, other) and
                self.text == other.text)

    def __ne__(self, other):
        """Non-equality test"""
        return not self.__eq__(other)

    def __repr__(self):
        """String representation; payloads longer than 10 bytes are
        abbreviated to their first 7 bytes followed by ``...``."""
        # The constructor accepts text=None, so len() must not be applied
        # unconditionally; a repr must never raise.
        if self.text is not None and len(self.text) > 10:
            return self.to_string(self.text[:7]) + "..."
        return self.to_string(self.text)
class DNSService(DNSRecord):
    """A DNS service record"""

    def __init__(self, name, type_, class_, ttl,
                 priority, weight, port, server):
        DNSRecord.__init__(self, name, type_, class_, ttl)
        self.priority = priority
        self.weight = weight
        self.port = port
        self.server = server

    def write(self, out):
        """Used in constructing an outgoing packet"""
        for number in (self.priority, self.weight, self.port):
            out.write_short(number)
        out.write_name(self.server)

    def __eq__(self, other):
        """Tests equality on priority, weight, port and server"""
        if not isinstance(other, DNSService):
            return False
        if not DNSEntry.__eq__(self, other):
            return False
        if self.priority != other.priority:
            return False
        if self.weight != other.weight:
            return False
        if self.port != other.port:
            return False
        return self.server == other.server

    def __ne__(self, other):
        """Non-equality test"""
        return not self.__eq__(other)

    def __repr__(self):
        """String representation"""
        target = "%s:%s" % (self.server, self.port)
        return self.to_string(target)
class DNSIncoming(QuietLogger):
    """Object representation of an incoming DNS packet"""

    def __init__(self, data):
        """Parse ``data`` (the raw bytes of a packet).

        Decode failures do not propagate: they are logged and leave
        ``self.valid`` set to False.
        """
        self.offset = 0
        self.data = data
        self.questions = []
        self.answers = []
        self.id = 0
        self.flags = 0
        self.num_questions = 0
        self.num_answers = 0
        self.num_authorities = 0
        self.num_additionals = 0
        self.valid = False

        try:
            self.read_header()
            self.read_questions()
            self.read_others()
        except (IndexError, struct.error, IncomingDecodeError):
            self.log_exception_warning((
                'Choked at offset %d while unpacking %r', self.offset, data))
        else:
            self.valid = True

    def unpack(self, format_):
        """Unpack ``format_`` at the current offset and advance past it."""
        end = self.offset + struct.calcsize(format_)
        fields = struct.unpack(format_, self.data[self.offset:end])
        self.offset = end
        return fields

    def read_header(self):
        """Reads header portion of packet"""
        header = self.unpack(b'!6H')
        (self.id, self.flags, self.num_questions, self.num_answers,
         self.num_authorities, self.num_additionals) = header

    def read_questions(self):
        """Reads questions section of packet"""
        for _ in range(self.num_questions):
            name = self.read_name()
            type_, class_ = self.unpack(b'!HH')
            self.questions.append(DNSQuestion(name, type_, class_))

    def read_character_string(self):
        """Reads a length-prefixed character string from the packet"""
        length = self.data[self.offset]
        self.offset += 1
        return self.read_string(length)

    def read_string(self, length):
        """Reads a string of a given length from the packet"""
        start = self.offset
        self.offset = start + length
        return self.data[start:start + length]

    def read_unsigned_short(self):
        """Reads an unsigned short from the packet"""
        (value,) = self.unpack(b'!H')
        return value

    def read_others(self):
        """Reads the answers, authorities and additionals section of the
        packet; every recognised record is appended to ``self.answers``."""
        total = self.num_answers + self.num_authorities + self.num_additionals
        for _ in range(total):
            domain = self.read_name()
            type_, class_, ttl, length = self.unpack(b'!HHiH')

            rec = None
            if type_ == _TYPE_A:
                address = self.read_string(4)
                rec = DNSAddress(domain, type_, class_, ttl, address)
            elif type_ in (_TYPE_CNAME, _TYPE_PTR):
                alias = self.read_name()
                rec = DNSPointer(domain, type_, class_, ttl, alias)
            elif type_ == _TYPE_TXT:
                text = self.read_string(length)
                rec = DNSText(domain, type_, class_, ttl, text)
            elif type_ == _TYPE_SRV:
                # Wire order: priority, weight, port, target name.
                priority = self.read_unsigned_short()
                weight = self.read_unsigned_short()
                port = self.read_unsigned_short()
                server = self.read_name()
                rec = DNSService(domain, type_, class_, ttl,
                                 priority, weight, port, server)
            elif type_ == _TYPE_HINFO:
                cpu = self.read_character_string()
                os_ = self.read_character_string()
                rec = DNSHinfo(domain, type_, class_, ttl, cpu, os_)
            elif type_ == _TYPE_AAAA:
                address = self.read_string(16)
                rec = DNSAddress(domain, type_, class_, ttl, address)
            else:
                # Try to ignore types we don't know about
                # Skip the payload for the resource record so the next
                # records can be parsed correctly
                self.offset += length

            if rec is not None:
                self.answers.append(rec)

    def is_query(self) -> bool:
        """Returns true if this is a query"""
        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY

    def is_response(self):
        """Returns true if this is a response"""
        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE

    def read_utf(self, offset, length):
        """Reads a UTF-8 string of a given length from the packet"""
        raw = self.data[offset:offset + length]
        return str(raw, 'utf-8', 'replace')

    def read_name(self):
        """Reads a (possibly compressed) domain name from the packet"""
        labels = []
        off = self.offset
        # Offset just past the first compression pointer; that is where
        # parsing resumes once the name is complete.
        resume_at = None
        # Pointers must always jump strictly backwards, which rules out loops.
        first = off

        while True:
            length = self.data[off]
            off += 1
            if length == 0:
                break
            kind = length & 0xC0
            if kind == 0x00:
                # ordinary label
                labels.append(self.read_utf(off, length))
                off += length
            elif kind == 0xC0:
                # compression pointer: 14-bit offset into the packet
                if resume_at is None:
                    resume_at = off + 1
                off = ((length & 0x3F) << 8) | self.data[off]
                if off >= first:
                    raise IncomingDecodeError(
                        "Bad domain name (circular) at %s" % (off,))
                first = off
            else:
                raise IncomingDecodeError("Bad domain name at %s" % (off,))

        self.offset = off if resume_at is None else resume_at
        return ''.join(label + '.' for label in labels)
class DNSOutgoing:
    """Object representation of an outgoing packet"""

    class State(enum.Enum):
        init = 0
        finished = 1

    def __init__(self, flags, multicast=True):
        self.finished = False
        self.id = 0
        self.multicast = multicast
        self.flags = flags
        self.names = {}  # name suffix -> packet offset, for name compression
        self.data = []  # list of byte chunks, joined by packet()
        self.size = 12  # the fixed DNS header occupies the first 12 bytes

        self.state = self.State.init

        self.questions = []
        self.answers = []  # list of (record, time) tuples
        self.authorities = []
        self.additionals = []

    def __repr__(self):
        return '<DNSOutgoing:{%s}>' % ', '.join([
            'multicast=%s' % self.multicast,
            'flags=%s' % self.flags,
            'questions=%s' % self.questions,
            'answers=%s' % self.answers,
            'authorities=%s' % self.authorities,
            'additionals=%s' % self.additionals,
        ])

    def add_question(self, record):
        """Adds a question"""
        self.questions.append(record)

    def add_answer(self, inp, record):
        """Adds an answer unless the incoming message ``inp`` already
        carries a sufficient known answer."""
        if not record.suppressed_by(inp):
            self.add_answer_at_time(record, 0)

    def add_answer_at_time(self, record, now):
        """Adds an answer if it does not expire by a certain time
        (``now == 0`` disables the expiry check)."""
        if record is not None:
            if now == 0 or not record.is_expired(now):
                self.answers.append((record, now))

    def add_authorative_answer(self, record):
        """Adds an authoritative answer"""
        self.authorities.append(record)

    def add_additional_answer(self, record):
        """ Adds an additional answer

        From: RFC 6763, DNS-Based Service Discovery, February 2013

        12.  DNS Additional Record Generation

           DNS has an efficiency feature whereby a DNS server may place
           additional records in the additional section of the DNS message.
           These additional records are records that the client did not
           explicitly request, but the server has reasonable grounds to expect
           that the client might request them shortly, so including them can
           save the client from having to issue additional queries.

           This section recommends which additional records SHOULD be generated
           to improve network efficiency, for both Unicast and Multicast DNS-SD
           responses.

        12.1.  PTR Records

           When including a DNS-SD Service Instance Enumeration or Selective
           Instance Enumeration (subtype) PTR record in a response packet, the
           server/responder SHOULD include the following additional records:

           o  The SRV record(s) named in the PTR rdata.
           o  The TXT record(s) named in the PTR rdata.
           o  All address records (type "A" and "AAAA") named in the SRV rdata.

        12.2.  SRV Records

           When including an SRV record in a response packet, the
           server/responder SHOULD include the following additional records:

           o  All address records (type "A" and "AAAA") named in the SRV rdata.

        """
        self.additionals.append(record)

    def pack(self, format_, value):
        """Append ``value`` packed with ``format_`` and grow ``size``."""
        self.data.append(struct.pack(format_, value))
        self.size += struct.calcsize(format_)

    def write_byte(self, value):
        """Writes a single byte to the packet"""
        # '!B' produces the same single byte as packing int2byte(value)
        # with '!c', and raises struct.error for out-of-range values too.
        self.pack(b'!B', value)

    def insert_short(self, index, value):
        """Inserts an unsigned short in a certain position in the packet"""
        self.data.insert(index, struct.pack(b'!H', value))
        self.size += 2

    def write_short(self, value):
        """Writes an unsigned short to the packet"""
        self.pack(b'!H', value)

    def write_int(self, value):
        """Writes an unsigned integer to the packet"""
        self.pack(b'!I', int(value))

    def write_string(self, value):
        """Writes a string to the packet"""
        assert isinstance(value, bytes)
        self.data.append(value)
        self.size += len(value)

    def write_utf(self, s):
        """Writes one name label (length byte + UTF-8 bytes) to the packet

        :raises NamePartTooLongException: if the encoded label exceeds 63 bytes
        """
        utfstr = s.encode('utf-8')
        length = len(utfstr)
        # RFC 1035 limits a label to 63 octets: the two top bits of the
        # length byte are reserved, so 64 (0x40) would not parse as a label.
        if length > 63:
            raise NamePartTooLongException
        self.write_byte(length)
        self.write_string(utfstr)

    def write_character_string(self, value):
        """Writes a length-prefixed character string to the packet

        :raises NamePartTooLongException: if ``value`` exceeds 255 bytes
        """
        assert isinstance(value, bytes)
        length = len(value)
        # The length prefix is a single byte, so 255 is the hard maximum.
        if length > 255:
            raise NamePartTooLongException
        self.write_byte(length)
        self.write_string(value)

    def write_name(self, name):
        """
        Write names to packet

        18.14. Name Compression

        When generating Multicast DNS messages, implementations SHOULD use
        name compression wherever possible to compress the names of resource
        records, by replacing some or all of the resource record name with a
        compact two-byte reference to an appearance of that data somewhere
        earlier in the message [RFC1035].
        """

        # split name into each label
        parts = name.split('.')
        if not parts[-1]:
            parts.pop()

        # construct each suffix
        name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))]

        # look for an existing name or suffix
        for count, sub_name in enumerate(name_suffices):
            if sub_name in self.names:
                break
        else:
            count = len(name_suffices)

        # note the new names we are saving into the packet
        # A suffix starts (full length - suffix length) bytes after the
        # current position: every dot maps onto exactly one length byte.
        # Measuring the dot-less join (not ``name``) keeps the offsets right
        # whether or not the caller supplied a trailing dot.
        name_length = len('.'.join(parts).encode('utf-8'))
        for suffix in name_suffices[:count]:
            self.names[suffix] = (
                self.size + name_length - len(suffix.encode('utf-8')))

        # write the new names out.
        for part in parts[:count]:
            self.write_utf(part)

        # if we wrote part of the name, create a pointer to the rest
        if count != len(name_suffices):
            # Found substring in packet, create pointer
            index = self.names[name_suffices[count]]
            self.write_byte((index >> 8) | 0xC0)
            self.write_byte(index & 0xFF)
        else:
            # this is the end of a name
            self.write_byte(0)

    def write_question(self, question):
        """Writes a question to the packet"""
        self.write_name(question.name)
        self.write_short(question.type)
        self.write_short(question.class_)

    def write_record(self, record, now):
        """Writes a record (answer, authoritative answer, additional) to
        the packet

        :return: 0 if the record was written, 1 if it was dropped because
            the packet is already finished or would grow too large
        """
        if self.state == self.State.finished:
            return 1

        # remember where we started so an oversized record can be undone
        start_data_length, start_size = len(self.data), self.size
        self.write_name(record.name)
        self.write_short(record.type)
        if record.unique and self.multicast:
            self.write_short(record.class_ | _CLASS_UNIQUE)
        else:
            self.write_short(record.class_)
        if now == 0:
            self.write_int(record.ttl)
        else:
            self.write_int(record.get_remaining_ttl(now))
        index = len(self.data)

        # Adjust size for the short we will write before this record
        self.size += 2
        record.write(self)
        self.size -= 2

        length = sum(len(d) for d in self.data[index:])
        # Here is the short we adjusted for
        self.insert_short(index, length)

        # if we go over, then rollback and quit
        if self.size > _MAX_MSG_ABSOLUTE:
            while len(self.data) > start_data_length:
                self.data.pop()
            self.size = start_size
            self.state = self.State.finished
            return 1
        return 0

    def packet(self) -> bytes:
        """Returns a string containing the packet's bytes

        No further parts should be added to the packet once this
        is done."""

        overrun_answers, overrun_authorities, overrun_additionals = 0, 0, 0

        if self.state != self.State.finished:
            for question in self.questions:
                self.write_question(question)
            for answer, time_ in self.answers:
                overrun_answers += self.write_record(answer, time_)
            for authority in self.authorities:
                overrun_authorities += self.write_record(authority, 0)
            for additional in self.additionals:
                overrun_additionals += self.write_record(additional, 0)
            self.state = self.State.finished

            # The header is prepended field by field in reverse order; the
            # section counts exclude records that did not fit.
            self.insert_short(0, len(self.additionals) - overrun_additionals)
            self.insert_short(0, len(self.authorities) - overrun_authorities)
            self.insert_short(0, len(self.answers) - overrun_answers)
            self.insert_short(0, len(self.questions))
            self.insert_short(0, self.flags)
            if self.multicast:
                self.insert_short(0, 0)
            else:
                self.insert_short(0, self.id)
        return b''.join(self.data)
class DNSCache:
    """A cache of DNS entries"""

    def __init__(self):
        # Maps an entry's ``key`` to a list of entries, newest first.
        self.cache = {}

    def add(self, entry):
        """Adds an entry"""
        # Insert first in list so get returns newest entry
        self.cache.setdefault(entry.key, []).insert(0, entry)

    def remove(self, entry):
        """Removes an entry; unknown entries are silently ignored."""
        try:
            list_ = self.cache[entry.key]
            list_.remove(entry)
        except (KeyError, ValueError):
            pass

    def get(self, entry):
        """Gets an entry by key.  Will return None if there is no
        matching entry."""
        # The probe's own __eq__ is used deliberately, so that a plain
        # DNSEntry can be matched against cached records.
        for cached_entry in self.cache.get(entry.key, ()):
            if entry.__eq__(cached_entry):
                return cached_entry
        return None

    def get_by_details(self, name, type_, class_):
        """Gets an entry by details.  Will return None if there is
        no matching entry."""
        entry = DNSEntry(name, type_, class_)
        return self.get(entry)

    def entries_with_name(self, name):
        """Returns a list of entries whose key matches the name."""
        try:
            return self.cache[name.lower()]
        except KeyError:
            return []

    def current_entry_with_name_and_alias(self, name, alias):
        """Returns the first unexpired PTR record for ``name`` whose alias
        equals ``alias``, or None."""
        now = current_time_millis()
        for record in self.entries_with_name(name):
            if (record.type == _TYPE_PTR and
                    not record.is_expired(now) and
                    record.alias == alias):
                return record
        return None

    def entries(self):
        """Returns a new list of all entries"""
        # Copy the values first to avoid size change during iteration, then
        # flatten into a fresh list.  A fresh list matters: folding the
        # per-key lists together with ``+`` hands back the *internal* list
        # when only one key is cached, so callers that remove entries while
        # iterating would skip records.
        return [entry
                for bucket in list(self.cache.values())
                for entry in bucket]
class Engine(threading.Thread):

    """An engine wraps read access to sockets, allowing objects that
    need to receive data from sockets to be called back when the
    sockets are ready.

    A reader needs a handle_read() method, which is called when the socket
    it is interested in is ready for reading.

    Writers are not implemented here, because we only send short
    packets.
    """

    def __init__(self, zc):
        threading.Thread.__init__(self, name='zeroconf-Engine')
        self.daemon = True
        self.zc = zc
        self.readers = {}  # maps socket to reader
        self.timeout = 5
        self.condition = threading.Condition()
        self.start()

    def run(self):
        """Poll the registered sockets until ``self.zc.done`` becomes true."""
        while not self.zc.done:
            with self.condition:
                if not self.readers:
                    # No sockets to manage, but we wait for the timeout
                    # or addition of a socket
                    self.condition.wait(self.timeout)
                # Take the snapshot while still holding the lock: a live
                # dict view could change size while select() walks it if
                # another thread adds or removes a reader.
                rs = list(self.readers)

            if rs:
                try:
                    rr, wr, er = select.select(rs, [], [], self.timeout)
                    if not self.zc.done:
                        for socket_ in rr:
                            reader = self.readers.get(socket_)
                            if reader:
                                reader.handle_read(socket_)

                except OSError as e:
                    # select.error and socket.error are both OSError.
                    # If the socket was closed by another thread, during
                    # shutdown, ignore it and exit
                    if e.errno != errno.EBADF or not self.zc.done:
                        raise

    def add_reader(self, reader, socket_):
        """Register ``reader`` for ``socket_`` and wake the polling loop."""
        with self.condition:
            self.readers[socket_] = reader
            self.condition.notify()

    def del_reader(self, socket_):
        """Unregister ``socket_`` (KeyError if unknown) and wake the loop."""
        with self.condition:
            del self.readers[socket_]
            self.condition.notify()
class Listener(QuietLogger):

    """A Listener is used by this module to listen on the multicast
    group to which DNS messages are sent, allowing the implementation
    to cache information as it arrives.

    It requires registration with an Engine object in order to have
    the handle_read() method called when a socket is available for
    reading."""

    def __init__(self, zc):
        self.zc = zc
        self.data = None  # raw bytes of the most recently received datagram

    def handle_read(self, socket_):
        """Receive one datagram from ``socket_`` and dispatch it to ``zc``."""
        try:
            data, (addr, port) = socket_.recvfrom(_MAX_MSG_ABSOLUTE)
        except Exception:
            self.log_exception_warning()
            return

        log.debug('Received from %r:%r: %r ', addr, port, data)

        self.data = data
        msg = DNSIncoming(data)
        if not msg.valid:
            # undecodable packet: nothing to do
            return

        if not msg.is_query():
            self.zc.handle_response(msg)
            return

        if port == _MDNS_PORT:
            # Always multicast responses
            self.zc.handle_query(msg, _MDNS_ADDR, _MDNS_PORT)
        elif port == _DNS_PORT:
            # If it's not a multicast query, reply via unicast
            # and multicast
            self.zc.handle_query(msg, addr, port)
            self.zc.handle_query(msg, _MDNS_ADDR, _MDNS_PORT)
class Reaper(threading.Thread):

    """A Reaper is used by this module to remove cache entries that
    have expired."""

    def __init__(self, zc):
        threading.Thread.__init__(self, name='zeroconf-Reaper')
        self.daemon = True
        self.zc = zc
        self.start()

    def run(self):
        """Every wait interval, report and drop expired cache records until
        ``zc.done`` is set."""
        interval = 10 * 1000
        while True:
            self.zc.wait(interval)
            if self.zc.done:
                return
            now = current_time_millis()
            expired = (record for record in self.zc.cache.entries()
                       if record.is_expired(now))
            for record in expired:
                self.zc.update_record(now, record)
                self.zc.cache.remove(record)
class Signal:
    """Holds a list of callables and invokes all of them on fire()."""

    def __init__(self):
        self._handlers = []

    def fire(self, **kwargs):
        """Call every registered handler with `kwargs` as keyword arguments."""
        # Iterate over a snapshot so that handlers registered or removed
        # during dispatch do not affect this round.
        snapshot = list(self._handlers)
        for callback in snapshot:
            callback(**kwargs)

    @property
    def registration_interface(self):
        """A SignalRegistrationInterface operating on this signal's handler list."""
        return SignalRegistrationInterface(self._handlers)
class SignalRegistrationInterface:
    """Registration-only view of a handler list: it can add and remove
    handlers but offers no way to fire them."""

    def __init__(self, handlers):
        # The list is shared with the caller, not copied.
        self._handlers = handlers

    def register_handler(self, handler):
        """Append `handler` to the list; returns self so calls can be chained."""
        registry = self._handlers
        registry.append(handler)
        return self

    def unregister_handler(self, handler):
        """Remove `handler` from the list; returns self so calls can be chained.

        Raises ValueError (from list.remove) if the handler is not present.
        """
        registry = self._handlers
        registry.remove(handler)
        return self
class RecordUpdateListener:
    """Base class for objects that receive DNS record updates."""

    def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
        """Handle `record` as seen at time `now`; subclasses must override this."""
        raise NotImplementedError
class ServiceBrowser(RecordUpdateListener, threading.Thread):
    """Used to browse for a service of a specific type.

    The listener object will have its add_service() and
    remove_service() methods called when this browser
    discovers changes in the services availability."""

    def __init__(self, zc: 'Zeroconf', type_: str, handlers=None, listener=None,
                 addr: str = _MDNS_ADDR, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME) -> None:
        """Creates a browser for a specific type and starts its thread.

        zc: Zeroconf instance used to send queries and receive records
        type_: fully qualified service type name to browse for
        handlers: list of callables invoked with the keyword arguments
            zeroconf, service_type, name and state_change; for backwards
            compatibility an object that has an add_service attribute is
            treated as `listener`
        listener: object with add_service() and remove_service() methods
        addr, port: destination of the PTR queries
        delay: initial interval between queries in milliseconds; it is
            doubled after every query, capped at 20 seconds

        Raises BadTypeInNameException if type_ does not end with the
        service type name parsed from it.
        """
        assert handlers or listener, 'You need to specify at least one handler'
        if not type_.endswith(service_type_name(type_, allow_underscores=True)):
            raise BadTypeInNameException
        threading.Thread.__init__(
            self, name='zeroconf-ServiceBrowser_' + type_)
        self.daemon = True
        self.zc = zc
        self.type = type_
        self.addr = addr
        self.port = port
        self.multicast = (self.addr == _MDNS_ADDR)
        # Known services, keyed by lower-cased PTR alias.
        self.services = {}  # type: Dict[str, DNSRecord]
        self.next_time = current_time_millis()
        self.delay = delay
        # Callbacks queued by update_record() and drained by run().
        self._handlers_to_call = []  # type: List[Callable[[Zeroconf], None]]

        self._service_state_changed = Signal()

        self.done = False

        if hasattr(handlers, 'add_service'):
            listener = handlers
            handlers = None

        # Work on a copy so that appending the listener adapter below never
        # mutates the list object owned by the caller.
        handlers = list(handlers) if handlers else []

        if listener:
            def on_change(zeroconf, service_type, name, state_change):
                args = (zeroconf, service_type, name)
                if state_change is ServiceStateChange.Added:
                    listener.add_service(*args)
                elif state_change is ServiceStateChange.Removed:
                    listener.remove_service(*args)
                else:
                    raise NotImplementedError(state_change)
            handlers.append(on_change)

        for h in handlers:
            self.service_state_changed.register_handler(h)

        self.start()

    @property
    def service_state_changed(self) -> SignalRegistrationInterface:
        """Registration interface for service state change handlers."""
        return self._service_state_changed.registration_interface

    def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
        """Callback invoked by Zeroconf when new information arrives.

        Updates information required by browser in the Zeroconf cache."""

        def enqueue_callback(state_change: ServiceStateChange, name: str) -> None:
            # The signal is fired later from run(), not from the caller of
            # update_record().
            self._handlers_to_call.append(
                lambda zeroconf: self._service_state_changed.fire(
                    zeroconf=zeroconf,
                    service_type=self.type,
                    name=name,
                    state_change=state_change,
                ))

        if record.type == _TYPE_PTR and record.name == self.type:
            assert isinstance(record, DNSPointer)
            expired = record.is_expired(now)
            service_key = record.alias.lower()
            try:
                old_record = self.services[service_key]
            except KeyError:
                if not expired:
                    self.services[service_key] = record
                    enqueue_callback(ServiceStateChange.Added, record.alias)
            else:
                if not expired:
                    old_record.reset_ttl(record)
                else:
                    del self.services[service_key]
                    enqueue_callback(ServiceStateChange.Removed, record.alias)
                    return

            # Schedule the next query no later than the point computed by
            # get_expiration_time(75) for this record.
            expires = record.get_expiration_time(75)
            if expires < self.next_time:
                self.next_time = expires

    def cancel(self):
        """Stops the browser: unregisters it from Zeroconf and waits for
        its thread to finish."""
        self.done = True
        self.zc.remove_listener(self)
        self.join()

    def run(self):
        """Thread body: periodically sends PTR queries for the browsed type
        and delivers queued state change callbacks, one per iteration."""
        self.zc.add_listener(self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))

        while True:
            now = current_time_millis()
            if len(self._handlers_to_call) == 0 and self.next_time > now:
                self.zc.wait(self.next_time - now)
            if self.zc.done or self.done:
                return
            now = current_time_millis()
            if self.next_time <= now:
                out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast)
                out.add_question(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
                # Include the records we already know and that are not stale.
                for record in self.services.values():
                    if not record.is_stale(now):
                        out.add_answer_at_time(record, now)

                self.zc.send(out, addr=self.addr, port=self.port)
                self.next_time = now + self.delay
                self.delay = min(20 * 1000, self.delay * 2)

            if len(self._handlers_to_call) > 0 and not self.zc.done:
                handler = self._handlers_to_call.pop(0)
                handler(self.zc)
# Type of the property mapping kept in ServiceInfo._properties.
ServicePropertiesType = Dict[bytes, Union[bool, str]]


class ServiceInfo(RecordUpdateListener):
    """Service information"""

    def __init__(self, type_: str, name: str, address: bytes = None, port: int = None, weight: int = 0,
                 priority: int = 0, properties=None, server: str = None) -> None:
        """Create a service description.

        type_: fully qualified service type name
        name: fully qualified service name
        address: IP address as bytes (presumably the packed form produced
                 by socket.inet_aton; verify against callers)
        port: port that the service runs on
        weight: weight of the service
        priority: priority of the service
        properties: dictionary of properties (or a string holding the
                    bytes for the text field)
        server: fully qualified name for service host (defaults to name)

        Raises BadTypeInNameException if type_ does not end with the
        service type name parsed from name."""

        if not type_.endswith(service_type_name(name, allow_underscores=True)):
            raise BadTypeInNameException
        self.type = type_
        self.name = name
        self.address = address
        self.port = port
        self.weight = weight
        self.priority = priority
        if server:
            self.server = server
        else:
            self.server = name
        self._properties = {}  # type: ServicePropertiesType
        self._set_properties(properties)
        # FIXME: this is here only so that mypy doesn't complain when we set and then use the attribute when
        # registering services. See if setting this to None by default is the right way to go.
        self.ttl = None  # type: Optional[int]

    @property
    def properties(self) -> ServicePropertiesType:
        """The property mapping of this service."""
        return self._properties

    def _set_properties(self, properties: Union[bytes, ServicePropertiesType]):
        """Sets properties and text of this info from a dictionary.

        A dict is stored as-is and also encoded into the TXT wire format
        (length-prefixed ``key=value`` items) in self.text. Any other
        value is stored unchanged in self.text."""
        if isinstance(properties, dict):
            self._properties = properties
            list_ = []
            result = b''
            for key, value in properties.items():
                if isinstance(key, str):
                    key = key.encode('utf-8')

                if value is None:
                    suffix = b''
                elif isinstance(value, str):
                    suffix = value.encode('utf-8')
                elif isinstance(value, bytes):
                    suffix = value
                elif isinstance(value, int):
                    # Integers (including bools) are encoded by truthiness.
                    if value:
                        suffix = b'true'
                    else:
                        suffix = b'false'
                else:
                    suffix = b''
                list_.append(b'='.join((key, suffix)))
            for item in list_:
                result = b''.join((result, int2byte(len(item)), item))
            self.text = result
        else:
            self.text = properties

    def _set_text(self, text):
        """Sets properties and text given a text field"""
        self.text = text
        result = {}
        end = len(text)
        index = 0
        strs = []
        # Split the length-prefixed items.
        while index < end:
            length = text[index]
            index += 1
            strs.append(text[index:index + length])
            index += length

        for s in strs:
            parts = s.split(b'=', 1)
            try:
                key, value = parts
            except ValueError:
                # No equals sign at all
                key = s
                value = False
            else:
                if value == b'true':
                    value = True
                elif value == b'false' or not value:
                    value = False

            # Only update non-existent properties
            if key and result.get(key) is None:
                result[key] = value

        self._properties = result

    def get_name(self):
        """Name accessor: the service name with the trailing ".<type>"
        removed when it is present."""
        if self.type is not None and self.name.endswith("." + self.type):
            return self.name[:len(self.name) - len(self.type) - 1]
        return self.name

    def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
        """Updates service information from a DNS record.

        None and expired records are ignored."""
        if record is not None and not record.is_expired(now):
            if record.type == _TYPE_A:
                assert isinstance(record, DNSAddress)
                # if record.name == self.name:
                if record.name == self.server:
                    self.address = record.address
            elif record.type == _TYPE_SRV:
                assert isinstance(record, DNSService)
                if record.name == self.name:
                    self.server = record.server
                    self.port = record.port
                    self.weight = record.weight
                    self.priority = record.priority
                    # self.address = None
                    # The server may have changed, so re-resolve its
                    # address from the cache.
                    self.update_record(
                        zc, now, zc.cache.get_by_details(
                            self.server, _TYPE_A, _CLASS_IN))
            elif record.type == _TYPE_TXT:
                assert isinstance(record, DNSText)
                if record.name == self.name:
                    self._set_text(record.text)

    def request(self, zc: 'Zeroconf', timeout: float) -> bool:
        """Returns true if the service could be discovered on the
        network, and updates this object with details discovered.

        timeout is in milliseconds. The cache is consulted first; if that
        does not yield server, address and text, queries are sent with a
        doubling interval until everything is known or the timeout expires.
        """
        now = current_time_millis()
        delay = _LISTENER_TIME
        next_ = now + delay
        last = now + timeout

        # SRV and TXT records are cached under the service instance name.
        for record_type in (_TYPE_SRV, _TYPE_TXT):
            cached = zc.cache.get_by_details(self.name, record_type, _CLASS_IN)
            if cached:
                self.update_record(zc, now, cached)
        # The A record belongs to the host, so it must be looked up by
        # server name; update_record() rejects A records with another name.
        if self.server is not None:
            cached = zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN)
            if cached:
                self.update_record(zc, now, cached)

        if None not in (self.server, self.address, self.text):
            return True

        try:
            zc.add_listener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN))
            while None in (self.server, self.address, self.text):
                if last <= now:
                    return False
                if next_ <= now:
                    out = DNSOutgoing(_FLAGS_QR_QUERY)
                    out.add_question(
                        DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN))
                    out.add_answer_at_time(
                        zc.cache.get_by_details(
                            self.name, _TYPE_SRV, _CLASS_IN), now)

                    out.add_question(
                        DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN))
                    out.add_answer_at_time(
                        zc.cache.get_by_details(
                            self.name, _TYPE_TXT, _CLASS_IN), now)

                    if self.server is not None:
                        out.add_question(
                            DNSQuestion(self.server, _TYPE_A, _CLASS_IN))
                        out.add_answer_at_time(
                            zc.cache.get_by_details(
                                self.server, _TYPE_A, _CLASS_IN), now)
                    zc.send(out)
                    next_ = now + delay
                    delay *= 2

                zc.wait(min(next_, last) - now)
                now = current_time_millis()
        finally:
            zc.remove_listener(self)

        return True

    def __eq__(self, other: object) -> bool:
        """Tests equality of service name"""
        return isinstance(other, ServiceInfo) and other.name == self.name

    def __ne__(self, other: object) -> bool:
        """Non-equality test"""
        return not self.__eq__(other)

    def __repr__(self) -> str:
        """String representation"""
        return '%s(%s)' % (
            type(self).__name__,
            ', '.join(
                '%s=%r' % (name, getattr(self, name))
                for name in (
                    'type', 'name', 'address', 'port', 'weight', 'priority',
                    'server', 'properties',
                )
            )
        )
class ZeroconfServiceTypes:
    """
    Return all of the advertised services on any local networks
    """

    def __init__(self):
        self.found_services = set()

    def add_service(self, zc, type_, name):
        """Listener callback: remember the name of a discovered service type."""
        self.found_services.add(name)

    def remove_service(self, zc, type_, name):
        """Listener callback: removals are ignored."""
        pass

    @classmethod
    def find(cls, zc=None, timeout=5, interfaces=InterfaceChoice.All):
        """
        Return all of the advertised services on any local networks.

        :param zc: Zeroconf() instance.  Pass in if already have an
                instance running or if non-default interfaces are needed
        :param timeout: seconds to wait for any responses
        :param interfaces: interfaces for the Zeroconf instance created
                when zc is not given
        :return: tuple of service type strings
        """
        local_zc = zc or Zeroconf(interfaces=interfaces)
        listener = cls()
        browser = None
        try:
            browser = ServiceBrowser(
                local_zc, '_services._dns-sd._udp.local.', listener=listener)

            # wait for responses
            time.sleep(timeout)
        finally:
            # close down anything we opened, even if the browser could not
            # be created or the wait was interrupted
            if zc is None:
                local_zc.close()
            elif browser is not None:
                browser.cancel()

        return tuple(sorted(listener.found_services))
def get_all_addresses() -> List[str]:
    """Return the distinct IPv4 addresses of all adapters reported by
    ifaddr, leaving out addresses whose network prefix is 32."""
    unique = set()
    for adapter in ifaddr.get_adapters():
        for ip in adapter.ips:
            # A prefix of 32 is a host only netmask.
            if ip.is_IPv4 and ip.network_prefix != 32:
                unique.add(ip.ip)
    return list(unique)
def normalize_interface_choice(choice: Union[List[str], InterfaceChoice]) -> List[str]:
    """Turn an interface selection into a list of address strings.

    InterfaceChoice.Default yields a single empty string,
    InterfaceChoice.All yields get_all_addresses(), and an explicit
    list is returned unchanged.
    """
    if choice is InterfaceChoice.Default:
        return ['']
    if choice is InterfaceChoice.All:
        return get_all_addresses()
    assert isinstance(choice, list)
    return choice
def new_socket(port: int = _MDNS_PORT) -> socket.socket:
    """Create a UDP/IPv4 socket bound to `port` on all interfaces.

    SO_REUSEADDR is always set, and SO_REUSEPORT where the platform
    supports it. When `port` equals the mDNS port the multicast TTL (255)
    and multicast loopback options are set as well.

    The socket is closed before any error from configuring or binding it
    is re-raised.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        # SO_REUSEADDR should be equivalent to SO_REUSEPORT for
        # multicast UDP sockets (p 731, "TCP/IP Illustrated,
        # Volume 2"), but some BSD-derived systems require
        # SO_REUSEPORT to be specified explicitly.  Also, not all
        # versions of Python have SO_REUSEPORT available.
        # Catch OSError for kernel versions <3.9 because lacking
        # SO_REUSEPORT support.
        reuseport = getattr(socket, 'SO_REUSEPORT', None)
        if reuseport is not None:
            try:
                s.setsockopt(socket.SOL_SOCKET, reuseport, 1)
            except OSError as err:
                if err.errno != errno.ENOPROTOOPT:
                    raise

        # Compare by value: identity comparison of ints only works by
        # accident when the very same constant object is passed in.
        if port == _MDNS_PORT:
            # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and
            # IP_MULTICAST_LOOP socket options as an unsigned char.
            ttl = struct.pack(b'B', 255)
            s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl)
            loop = struct.pack(b'B', 1)
            s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop)

        s.bind(('', port))
    except Exception:
        s.close()
        raise
    return s
def get_errno(e: Exception) -> int:
    """Return the first argument of a socket error, which is its errno
    when the error was constructed with one."""
    assert isinstance(e, socket.error)
    first_arg = e.args[0]
    return first_arg
class Zeroconf(QuietLogger):
    """Implementation of Zeroconf Multicast DNS Service Discovery

    Supports registration, unregistration, queries and browsing.
    """

    def __init__(
        self,
        interfaces: Union[List[str], InterfaceChoice] = InterfaceChoice.All,
        unicast: bool = False
    ) -> None:
        """Creates an instance of the Zeroconf class, establishing
        multicast communications, listening and reaping threads.

        :type interfaces: :class:`InterfaceChoice` or sequence of ip addresses
        :param unicast: when true no multicast group is joined; one socket
            per interface, bound to an ephemeral port, is used for both
            sending and receiving
        """
        # hook for threads
        self._GLOBAL_DONE = False
        self.unicast = unicast

        if not unicast:
            self._listen_socket = new_socket()
        interfaces = normalize_interface_choice(interfaces)

        self._respond_sockets = []  # type: List[socket.socket]

        for i in interfaces:
            if not unicast:
                log.debug('Adding %r to multicast group', i)
                try:
                    _value = socket.inet_aton(_MDNS_ADDR) + socket.inet_aton(i)
                    self._listen_socket.setsockopt(
                        socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value)
                except socket.error as e:
                    _errno = get_errno(e)
                    if _errno == errno.EADDRINUSE:
                        # Not fatal: a respond socket is still created below.
                        log.info(
                            'Address in use when adding %s to multicast group, '
                            'it is expected to happen on some systems', i,
                        )
                    elif _errno == errno.EADDRNOTAVAIL:
                        log.info(
                            'Address not available when adding %s to multicast '
                            'group, it is expected to happen on some systems', i,
                        )
                        continue
                    elif _errno == errno.EINVAL:
                        log.info(
                            'Interface of %s does not support multicast, '
                            'it is expected in WSL', i
                        )
                        continue

                    else:
                        raise

                respond_socket = new_socket()
                respond_socket.setsockopt(
                    socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(i))
            else:
                respond_socket = new_socket(port=0)

            self._respond_sockets.append(respond_socket)

        self.listeners = []  # type: List[RecordUpdateListener]
        self.browsers = {}  # type: Dict[RecordUpdateListener, ServiceBrowser]
        # Registered services, keyed by lower-cased service name.
        self.services = {}  # type: Dict[str, ServiceInfo]
        # Number of registered services per service type.
        self.servicetypes = {}  # type: Dict[str, int]

        self.cache = DNSCache()

        self.condition = threading.Condition()

        self.engine = Engine(self)
        self.listener = Listener(self)
        if not unicast:
            self.engine.add_reader(self.listener, self._listen_socket)
        else:
            for s in self._respond_sockets:
                self.engine.add_reader(self.listener, s)
        self.reaper = Reaper(self)

        self.debug = None  # type: Optional[DNSOutgoing]

    @property
    def done(self) -> bool:
        """True once close() has shut this instance down."""
        return self._GLOBAL_DONE

    def wait(self, timeout: float) -> None:
        """Calling thread waits for a given number of milliseconds or
        until notified."""
        with self.condition:
            self.condition.wait(timeout / 1000.0)

    def notify_all(self) -> None:
        """Notifies all waiting threads"""
        with self.condition:
            self.condition.notify_all()

    def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]:
        """Returns network's service information for a particular
        name and type, or None if no service matches by the timeout,
        which defaults to 3 seconds."""
        info = ServiceInfo(type_, name)
        if info.request(self, timeout):
            return info
        return None

    def add_service_listener(self, type_: str, listener: RecordUpdateListener) -> None:
        """Adds a listener for a particular service type.  This object
        will then have its update_record method called when information
        arrives for that type."""
        self.remove_service_listener(listener)
        self.browsers[listener] = ServiceBrowser(self, type_, listener)

    def remove_service_listener(self, listener: RecordUpdateListener) -> None:
        """Removes a listener from the set that is currently listening."""
        if listener in self.browsers:
            self.browsers[listener].cancel()
            del self.browsers[listener]

    def remove_all_service_listeners(self) -> None:
        """Removes all listeners from the set that is currently listening."""
        # Iterate over a snapshot: removal mutates self.browsers.
        for listener in list(self.browsers):
            self.remove_service_listener(listener)

    def register_service(
        self, info: ServiceInfo, ttl: int = _DNS_TTL, allow_name_change: bool = False,
    ) -> None:
        """Registers service information to the network with a default TTL
        of _DNS_TTL (120 seconds).  Zeroconf will then respond to requests
        for information for that service.  The name of the service may be
        changed if needed to make it unique on the network.

        The service is announced three times, _REGISTER_TIME milliseconds
        apart; this call blocks until the last announcement is sent."""
        info.ttl = ttl
        self.check_service(info, allow_name_change)
        self.services[info.name.lower()] = info
        if info.type in self.servicetypes:
            self.servicetypes[info.type] += 1
        else:
            self.servicetypes[info.type] = 1
        now = current_time_millis()
        next_time = now
        i = 0
        while i < 3:
            if now < next_time:
                self.wait(next_time - now)
                now = current_time_millis()
                continue
            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
            out.add_answer_at_time(
                DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, ttl, info.name), 0)
            out.add_answer_at_time(
                DNSService(info.name, _TYPE_SRV, _CLASS_IN,
                           ttl, info.priority, info.weight, info.port,
                           info.server), 0)

            out.add_answer_at_time(
                DNSText(info.name, _TYPE_TXT, _CLASS_IN, ttl, info.text), 0)
            if info.address:
                out.add_answer_at_time(
                    DNSAddress(info.server, _TYPE_A, _CLASS_IN,
                               ttl, info.address), 0)
            self.send(out)
            i += 1
            next_time += _REGISTER_TIME

    def unregister_service(self, info: ServiceInfo) -> None:
        """Unregister a service.

        The service's records are sent three times with a TTL of zero,
        _UNREGISTER_TIME milliseconds apart. A failure to remove the
        service from the local tables is logged, not raised."""
        try:
            del self.services[info.name.lower()]
            if self.servicetypes[info.type] > 1:
                self.servicetypes[info.type] -= 1
            else:
                del self.servicetypes[info.type]
        except Exception as e:  # TODO stop catching all Exceptions
            log.exception('Unknown error, possibly benign: %r', e)
        now = current_time_millis()
        next_time = now
        i = 0
        while i < 3:
            if now < next_time:
                self.wait(next_time - now)
                now = current_time_millis()
                continue
            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
            out.add_answer_at_time(
                DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
            # Use the same server as the SRV record announced by
            # register_service(), so the goodbye refers to that record.
            out.add_answer_at_time(
                DNSService(info.name, _TYPE_SRV, _CLASS_IN, 0,
                           info.priority, info.weight, info.port, info.server), 0)
            out.add_answer_at_time(
                DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)

            if info.address:
                out.add_answer_at_time(
                    DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0,
                               info.address), 0)
            self.send(out)
            i += 1
            next_time += _UNREGISTER_TIME

    def unregister_all_services(self) -> None:
        """Unregister all registered services.

        Sends the records of every registered service three times with a
        TTL of zero. The local service tables are left untouched."""
        if len(self.services) > 0:
            now = current_time_millis()
            next_time = now
            i = 0
            while i < 3:
                if now < next_time:
                    self.wait(next_time - now)
                    now = current_time_millis()
                    continue
                out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
                for info in self.services.values():
                    out.add_answer_at_time(DNSPointer(
                        info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
                    out.add_answer_at_time(DNSService(
                        info.name, _TYPE_SRV, _CLASS_IN, 0,
                        info.priority, info.weight, info.port, info.server), 0)
                    out.add_answer_at_time(DNSText(
                        info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
                    if info.address:
                        out.add_answer_at_time(DNSAddress(
                            info.server, _TYPE_A, _CLASS_IN, 0,
                            info.address), 0)
                self.send(out)
                i += 1
                next_time += _UNREGISTER_TIME

    def check_service(self, info: ServiceInfo, allow_name_change: bool) -> None:
        """Checks the network for a unique service name, modifying the
        ServiceInfo passed in if it is not unique.

        Raises BadTypeInNameException if info.type does not match
        info.name, and NonUniqueNameException if the name is taken and
        allow_name_change is false."""

        # This is kind of funky because of the subtype based tests
        # need to make subtypes a first class citizen
        service_name = service_type_name(info.name)
        if not info.type.endswith(service_name):
            raise BadTypeInNameException

        instance_name = info.name[:-len(service_name) - 1]
        next_instance_number = 2

        now = current_time_millis()
        next_time = now
        i = 0
        while i < 3:
            # check for a name conflict
            while self.cache.current_entry_with_name_and_alias(
                    info.type, info.name):
                if not allow_name_change:
                    raise NonUniqueNameException

                # change the name and look for a conflict
                info.name = '%s-%s.%s' % (
                    instance_name, next_instance_number, info.type)
                next_instance_number += 1
                service_type_name(info.name)
                # Restart the three probes for the new name.
                next_time = now
                i = 0

            if now < next_time:
                self.wait(next_time - now)
                now = current_time_millis()
                continue

            out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
            self.debug = out
            out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN))
            out.add_authorative_answer(DNSPointer(
                info.type, _TYPE_PTR, _CLASS_IN, info.ttl, info.name))
            self.send(out)
            i += 1
            next_time += _CHECK_TIME

    def add_listener(self, listener: RecordUpdateListener, question: Optional[DNSQuestion]) -> None:
        """Adds a listener for a given question.  The listener will have
        its update_record method called when information is available to
        answer the question."""
        now = current_time_millis()
        self.listeners.append(listener)
        if question is not None:
            # Replay matching records that are already cached.
            for record in self.cache.entries_with_name(question.name):
                if question.answered_by(record) and not record.is_expired(now):
                    listener.update_record(self, now, record)
        self.notify_all()

    def remove_listener(self, listener: RecordUpdateListener) -> None:
        """Removes a listener.

        A failure (for example an unknown listener) is logged, not raised."""
        try:
            self.listeners.remove(listener)
            self.notify_all()
        except Exception as e:  # TODO stop catching all Exceptions
            log.exception('Unknown error, possibly benign: %r', e)

    def update_record(self, now: float, rec: DNSRecord) -> None:
        """Used to notify listeners of new information that has updated
        a record."""
        for listener in self.listeners:
            listener.update_record(self, now, rec)
        self.notify_all()

    def handle_response(self, msg: DNSIncoming) -> None:
        """Deal with incoming response packets.  All answers
        are held in the cache, and listeners are notified."""
        now = current_time_millis()
        for record in msg.answers:
            expired = record.is_expired(now)
            if record in self.cache.entries():
                if expired:
                    self.cache.remove(record)
                else:
                    entry = self.cache.get(record)
                    if entry is not None:
                        entry.reset_ttl(record)
            else:
                self.cache.add(record)

        for record in msg.answers:
            self.update_record(now, record)

    def handle_query(self, msg: DNSIncoming, addr: str, port: int) -> None:
        """Deal with incoming query packets.  Provides a response if
        possible.

        The response is sent to (addr, port) and only if it contains at
        least one answer."""
        out = None

        # Support unicast client responses
        #
        if port != _MDNS_PORT:
            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False)
            for question in msg.questions:
                out.add_question(question)

        for question in msg.questions:
            if question.type == _TYPE_PTR:
                if question.name == "_services._dns-sd._udp.local.":
                    for stype in self.servicetypes.keys():
                        if out is None:
                            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
                        out.add_answer(msg, DNSPointer(
                            "_services._dns-sd._udp.local.", _TYPE_PTR,
                            _CLASS_IN, _DNS_TTL, stype))
                for service in self.services.values():
                    if question.name == service.type:
                        if out is None:
                            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
                        out.add_answer(msg, DNSPointer(
                            service.type, _TYPE_PTR,
                            _CLASS_IN, service.ttl, service.name))
            else:
                try:
                    if out is None:
                        out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)

                    # Answer A record queries for any service addresses we know
                    if question.type in (_TYPE_A, _TYPE_ANY):
                        for service in self.services.values():
                            if service.server == question.name.lower():
                                out.add_answer(msg, DNSAddress(
                                    question.name, _TYPE_A,
                                    _CLASS_IN | _CLASS_UNIQUE,
                                    service.ttl, service.address))

                    name_to_find = question.name.lower()
                    if name_to_find not in self.services:
                        continue
                    service = self.services[name_to_find]

                    if question.type in (_TYPE_SRV, _TYPE_ANY):
                        out.add_answer(msg, DNSService(
                            question.name, _TYPE_SRV, _CLASS_IN | _CLASS_UNIQUE,
                            service.ttl, service.priority, service.weight,
                            service.port, service.server))
                    if question.type in (_TYPE_TXT, _TYPE_ANY):
                        out.add_answer(msg, DNSText(
                            question.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE,
                            service.ttl, service.text))
                    if question.type == _TYPE_SRV:
                        out.add_additional_answer(DNSAddress(
                            service.server, _TYPE_A, _CLASS_IN | _CLASS_UNIQUE,
                            service.ttl, service.address))
                except Exception:  # TODO stop catching all Exceptions
                    self.log_exception_warning()

        if out is not None and out.answers:
            out.id = msg.id
            self.send(out, addr, port)

    def send(self, out: DNSOutgoing, addr: str = _MDNS_ADDR, port: int = _MDNS_PORT) -> None:
        """Sends an outgoing packet.

        The packet is sent once through every respond socket. Over-sized
        packets are dropped with a warning, and nothing is sent once the
        instance is done."""
        packet = out.packet()
        if len(packet) > _MAX_MSG_ABSOLUTE:
            self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r",
                                  out, len(packet), packet)
            return
        log.debug('Sending %r (%d bytes) as %r...', out, len(packet), packet)
        for s in self._respond_sockets:
            if self._GLOBAL_DONE:
                return
            try:
                bytes_sent = s.sendto(packet, 0, (addr, port))
            except Exception:   # TODO stop catching all Exceptions
                # on send errors, log the exception and keep going
                self.log_exception_warning()
            else:
                if bytes_sent != len(packet):
                    self.log_warning_once(
                        '!!! sent %d out of %d bytes to %r' % (
                            bytes_sent, len(packet), s))

    def close(self) -> None:
        """Ends the background threads, and prevent this instance from
        servicing further queries."""
        if self._GLOBAL_DONE:
            return
        # remove service listeners
        self.remove_all_service_listeners()
        # Say goodbye before flagging the instance as done: send() drops
        # every packet once _GLOBAL_DONE is set.
        self.unregister_all_services()
        self._GLOBAL_DONE = True

        # shutdown recv socket and thread
        if not self.unicast:
            self.engine.del_reader(self._listen_socket)
            self._listen_socket.close()
        else:
            for s in self._respond_sockets:
                self.engine.del_reader(s)
        self.engine.join()

        # shutdown the rest
        self.notify_all()
        self.reaper.join()
        for s in self._respond_sockets:
            s.close()