You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

2133 lines
72 KiB

4 years ago
  1. """ Multicast DNS Service Discovery for Python, v0.14-wmcbrine
  2. Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
  3. This module provides a framework for the use of DNS Service Discovery
  4. using IP multicast.
  5. This library is free software; you can redistribute it and/or
  6. modify it under the terms of the GNU Lesser General Public
  7. License as published by the Free Software Foundation; either
  8. version 2.1 of the License, or (at your option) any later version.
  9. This library is distributed in the hope that it will be useful,
  10. but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. Lesser General Public License for more details.
  13. You should have received a copy of the GNU Lesser General Public
  14. License along with this library; if not, write to the Free Software
  15. Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
  16. USA
  17. """
  18. import enum
  19. import errno
  20. import logging
  21. import re
  22. import select
  23. import socket
  24. import struct
  25. import sys
  26. import threading
  27. import time
  28. from functools import reduce
  29. from typing import Callable # noqa # used in type hints
  30. from typing import Dict, List, Optional, Union
  31. import ifaddr
# Package metadata.
__author__ = 'Paul Scott-Murphy, William McBrine'
__maintainer__ = 'Jakub Stasiak <jakub@stasiak.at>'
__version__ = '0.21.3'
__license__ = 'LGPL'

# Names exported by ``from <module> import *``.  Zeroconf, ServiceInfo and
# ServiceBrowser are not defined in this part of the file — presumably further
# down; confirm before editing this list.
__all__ = [
    "__version__",
    "Zeroconf", "ServiceInfo", "ServiceBrowser",
    "Error", "InterfaceChoice", "ServiceStateChange",
]
  41. if sys.version_info <= (3, 3):
  42. raise ImportError('''
  43. Python version > 3.3 required for python-zeroconf.
  44. If you need support for Python 2 or Python 3.3 please use version 19.1
  45. ''')
  46. log = logging.getLogger(__name__)
  47. log.addHandler(logging.NullHandler())
  48. if log.level == logging.NOTSET:
  49. log.setLevel(logging.WARN)
  50. # Some timing constants
  51. _UNREGISTER_TIME = 125
  52. _CHECK_TIME = 175
  53. _REGISTER_TIME = 225
  54. _LISTENER_TIME = 200
  55. _BROWSER_TIME = 500
  56. # Some DNS constants
  57. _MDNS_ADDR = '224.0.0.251'
  58. _MDNS_PORT = 5353
  59. _DNS_PORT = 53
  60. _DNS_TTL = 120 # two minutes default TTL as recommended by RFC6762
  61. _MAX_MSG_TYPICAL = 1460 # unused
  62. _MAX_MSG_ABSOLUTE = 8966
  63. _FLAGS_QR_MASK = 0x8000 # query response mask
  64. _FLAGS_QR_QUERY = 0x0000 # query
  65. _FLAGS_QR_RESPONSE = 0x8000 # response
  66. _FLAGS_AA = 0x0400 # Authoritative answer
  67. _FLAGS_TC = 0x0200 # Truncated
  68. _FLAGS_RD = 0x0100 # Recursion desired
  69. _FLAGS_RA = 0x8000 # Recursion available
  70. _FLAGS_Z = 0x0040 # Zero
  71. _FLAGS_AD = 0x0020 # Authentic data
  72. _FLAGS_CD = 0x0010 # Checking disabled
  73. _CLASS_IN = 1
  74. _CLASS_CS = 2
  75. _CLASS_CH = 3
  76. _CLASS_HS = 4
  77. _CLASS_NONE = 254
  78. _CLASS_ANY = 255
  79. _CLASS_MASK = 0x7FFF
  80. _CLASS_UNIQUE = 0x8000
  81. _TYPE_A = 1
  82. _TYPE_NS = 2
  83. _TYPE_MD = 3
  84. _TYPE_MF = 4
  85. _TYPE_CNAME = 5
  86. _TYPE_SOA = 6
  87. _TYPE_MB = 7
  88. _TYPE_MG = 8
  89. _TYPE_MR = 9
  90. _TYPE_NULL = 10
  91. _TYPE_WKS = 11
  92. _TYPE_PTR = 12
  93. _TYPE_HINFO = 13
  94. _TYPE_MINFO = 14
  95. _TYPE_MX = 15
  96. _TYPE_TXT = 16
  97. _TYPE_AAAA = 28
  98. _TYPE_SRV = 33
  99. _TYPE_ANY = 255
  100. # Mapping constants to names
  101. _CLASSES = {_CLASS_IN: "in",
  102. _CLASS_CS: "cs",
  103. _CLASS_CH: "ch",
  104. _CLASS_HS: "hs",
  105. _CLASS_NONE: "none",
  106. _CLASS_ANY: "any"}
  107. _TYPES = {_TYPE_A: "a",
  108. _TYPE_NS: "ns",
  109. _TYPE_MD: "md",
  110. _TYPE_MF: "mf",
  111. _TYPE_CNAME: "cname",
  112. _TYPE_SOA: "soa",
  113. _TYPE_MB: "mb",
  114. _TYPE_MG: "mg",
  115. _TYPE_MR: "mr",
  116. _TYPE_NULL: "null",
  117. _TYPE_WKS: "wks",
  118. _TYPE_PTR: "ptr",
  119. _TYPE_HINFO: "hinfo",
  120. _TYPE_MINFO: "minfo",
  121. _TYPE_MX: "mx",
  122. _TYPE_TXT: "txt",
  123. _TYPE_AAAA: "quada",
  124. _TYPE_SRV: "srv",
  125. _TYPE_ANY: "any"}
  126. _HAS_A_TO_Z = re.compile(r'[A-Za-z]')
  127. _HAS_ONLY_A_TO_Z_NUM_HYPHEN = re.compile(r'^[A-Za-z0-9\-]+$')
  128. _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$')
  129. _HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]')
  130. int2byte = struct.Struct(">B").pack
  131. @enum.unique
  132. class InterfaceChoice(enum.Enum):
  133. Default = 1
  134. All = 2
  135. @enum.unique
  136. class ServiceStateChange(enum.Enum):
  137. Added = 1
  138. Removed = 2
  139. # utility functions
  140. def current_time_millis() -> float:
  141. """Current system time in milliseconds"""
  142. return time.time() * 1000
  143. def service_type_name(type_, *, allow_underscores: bool = False):
  144. """
  145. Validate a fully qualified service name, instance or subtype. [rfc6763]
  146. Returns fully qualified service name.
  147. Domain names used by mDNS-SD take the following forms:
  148. <sn> . <_tcp|_udp> . local.
  149. <Instance> . <sn> . <_tcp|_udp> . local.
  150. <sub>._sub . <sn> . <_tcp|_udp> . local.
  151. 1) must end with 'local.'
  152. This is true because we are implementing mDNS and since the 'm' means
  153. multi-cast, the 'local.' domain is mandatory.
  154. 2) local is preceded with either '_udp.' or '_tcp.'
  155. 3) service name <sn> precedes <_tcp|_udp>
  156. The rules for Service Names [RFC6335] state that they may be no more
  157. than fifteen characters long (not counting the mandatory underscore),
  158. consisting of only letters, digits, and hyphens, must begin and end
  159. with a letter or digit, must not contain consecutive hyphens, and
  160. must contain at least one letter.
  161. The instance name <Instance> and sub type <sub> may be up to 63 bytes.
  162. The portion of the Service Instance Name is a user-
  163. friendly name consisting of arbitrary Net-Unicode text [RFC5198]. It
  164. MUST NOT contain ASCII control characters (byte values 0x00-0x1F and
  165. 0x7F) [RFC20] but otherwise is allowed to contain any characters,
  166. without restriction, including spaces, uppercase, lowercase,
  167. punctuation -- including dots -- accented characters, non-Roman text,
  168. and anything else that may be represented using Net-Unicode.
  169. :param type_: Type, SubType or service name to validate
  170. :return: fully qualified service name (eg: _http._tcp.local.)
  171. """
  172. if not (type_.endswith('._tcp.local.') or type_.endswith('._udp.local.')):
  173. raise BadTypeInNameException(
  174. "Type '%s' must end with '._tcp.local.' or '._udp.local.'" %
  175. type_)
  176. remaining = type_[:-len('._tcp.local.')].split('.')
  177. name = remaining.pop()
  178. if not name:
  179. raise BadTypeInNameException("No Service name found")
  180. if len(remaining) == 1 and len(remaining[0]) == 0:
  181. raise BadTypeInNameException(
  182. "Type '%s' must not start with '.'" % type_)
  183. if name[0] != '_':
  184. raise BadTypeInNameException(
  185. "Service name (%s) must start with '_'" % name)
  186. # remove leading underscore
  187. name = name[1:]
  188. if len(name) > 15:
  189. raise BadTypeInNameException(
  190. "Service name (%s) must be <= 15 bytes" % name)
  191. if '--' in name:
  192. raise BadTypeInNameException(
  193. "Service name (%s) must not contain '--'" % name)
  194. if '-' in (name[0], name[-1]):
  195. raise BadTypeInNameException(
  196. "Service name (%s) may not start or end with '-'" % name)
  197. if not _HAS_A_TO_Z.search(name):
  198. raise BadTypeInNameException(
  199. "Service name (%s) must contain at least one letter (eg: 'A-Z')" %
  200. name)
  201. allowed_characters_re = (
  202. _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE if allow_underscores
  203. else _HAS_ONLY_A_TO_Z_NUM_HYPHEN
  204. )
  205. if not allowed_characters_re.search(name):
  206. raise BadTypeInNameException(
  207. "Service name (%s) must contain only these characters: "
  208. "A-Z, a-z, 0-9, hyphen ('-')%s" % (name, ", underscore ('_')" if allow_underscores else ""))
  209. if remaining and remaining[-1] == '_sub':
  210. remaining.pop()
  211. if len(remaining) == 0 or len(remaining[0]) == 0:
  212. raise BadTypeInNameException(
  213. "_sub requires a subtype name")
  214. if len(remaining) > 1:
  215. remaining = ['.'.join(remaining)]
  216. if remaining:
  217. length = len(remaining[0].encode('utf-8'))
  218. if length > 63:
  219. raise BadTypeInNameException("Too long: '%s'" % remaining[0])
  220. if _HAS_ASCII_CONTROL_CHARS.search(remaining[0]):
  221. raise BadTypeInNameException(
  222. "Ascii control character 0x00-0x1F and 0x7F illegal in '%s'" %
  223. remaining[0])
  224. return '_' + name + type_[-len('._tcp.local.'):]
  225. # Exceptions
  226. class Error(Exception):
  227. pass
  228. class IncomingDecodeError(Error):
  229. pass
  230. class NonUniqueNameException(Error):
  231. pass
  232. class NamePartTooLongException(Error):
  233. pass
  234. class AbstractMethodException(Error):
  235. pass
  236. class BadTypeInNameException(Error):
  237. pass
  238. # implementation classes
  239. class QuietLogger:
  240. _seen_logs = {} # type: Dict[str, tuple]
  241. @classmethod
  242. def log_exception_warning(cls, logger_data=None):
  243. exc_info = sys.exc_info()
  244. exc_str = str(exc_info[1])
  245. if exc_str not in cls._seen_logs:
  246. # log at warning level the first time this is seen
  247. cls._seen_logs[exc_str] = exc_info
  248. logger = log.warning
  249. else:
  250. logger = log.debug
  251. if logger_data is not None:
  252. logger(*logger_data)
  253. logger('Exception occurred:', exc_info=exc_info)
  254. @classmethod
  255. def log_warning_once(cls, *args):
  256. msg_str = args[0]
  257. if msg_str not in cls._seen_logs:
  258. cls._seen_logs[msg_str] = 0
  259. logger = log.warning
  260. else:
  261. logger = log.debug
  262. cls._seen_logs[msg_str] += 1
  263. logger(*args)
  264. class DNSEntry:
  265. """A DNS entry"""
  266. def __init__(self, name, type_, class_):
  267. self.key = name.lower()
  268. self.name = name
  269. self.type = type_
  270. self.class_ = class_ & _CLASS_MASK
  271. self.unique = (class_ & _CLASS_UNIQUE) != 0
  272. def __eq__(self, other):
  273. """Equality test on name, type, and class"""
  274. return (isinstance(other, DNSEntry) and
  275. self.name == other.name and
  276. self.type == other.type and
  277. self.class_ == other.class_)
  278. def __ne__(self, other):
  279. """Non-equality test"""
  280. return not self.__eq__(other)
  281. @staticmethod
  282. def get_class_(class_):
  283. """Class accessor"""
  284. return _CLASSES.get(class_, "?(%s)" % class_)
  285. @staticmethod
  286. def get_type(t):
  287. """Type accessor"""
  288. return _TYPES.get(t, "?(%s)" % t)
  289. def to_string(self, hdr, other):
  290. """String representation with additional information"""
  291. result = "%s[%s,%s" % (hdr, self.get_type(self.type),
  292. self.get_class_(self.class_))
  293. if self.unique:
  294. result += "-unique,"
  295. else:
  296. result += ","
  297. result += self.name
  298. if other is not None:
  299. result += ",%s]" % other
  300. else:
  301. result += "]"
  302. return result
  303. class DNSQuestion(DNSEntry):
  304. """A DNS question entry"""
  305. def __init__(self, name: str, type_: int, class_: int) -> None:
  306. DNSEntry.__init__(self, name, type_, class_)
  307. def answered_by(self, rec: 'DNSRecord') -> bool:
  308. """Returns true if the question is answered by the record"""
  309. return (self.class_ == rec.class_ and
  310. (self.type == rec.type or self.type == _TYPE_ANY) and
  311. self.name == rec.name)
  312. def __repr__(self) -> str:
  313. """String representation"""
  314. return DNSEntry.to_string(self, "question", None)
  315. class DNSRecord(DNSEntry):
  316. """A DNS record - like a DNS entry, but has a TTL"""
  317. def __init__(self, name, type_, class_, ttl):
  318. DNSEntry.__init__(self, name, type_, class_)
  319. self.ttl = ttl
  320. self.created = current_time_millis()
  321. def __eq__(self, other):
  322. """Abstract method"""
  323. raise AbstractMethodException
  324. def __ne__(self, other):
  325. """Non-equality test"""
  326. return not self.__eq__(other)
  327. def suppressed_by(self, msg):
  328. """Returns true if any answer in a message can suffice for the
  329. information held in this record."""
  330. for record in msg.answers:
  331. if self.suppressed_by_answer(record):
  332. return True
  333. return False
  334. def suppressed_by_answer(self, other):
  335. """Returns true if another record has same name, type and class,
  336. and if its TTL is at least half of this record's."""
  337. return self == other and other.ttl > (self.ttl / 2)
  338. def get_expiration_time(self, percent):
  339. """Returns the time at which this record will have expired
  340. by a certain percentage."""
  341. return self.created + (percent * self.ttl * 10)
  342. def get_remaining_ttl(self, now):
  343. """Returns the remaining TTL in seconds."""
  344. return max(0, (self.get_expiration_time(100) - now) / 1000.0)
  345. def is_expired(self, now) -> bool:
  346. """Returns true if this record has expired."""
  347. return self.get_expiration_time(100) <= now
  348. def is_stale(self, now):
  349. """Returns true if this record is at least half way expired."""
  350. return self.get_expiration_time(50) <= now
  351. def reset_ttl(self, other):
  352. """Sets this record's TTL and created time to that of
  353. another record."""
  354. self.created = other.created
  355. self.ttl = other.ttl
  356. def write(self, out):
  357. """Abstract method"""
  358. raise AbstractMethodException
  359. def to_string(self, other):
  360. """String representation with additional information"""
  361. arg = "%s/%s,%s" % (
  362. self.ttl, self.get_remaining_ttl(current_time_millis()), other)
  363. return DNSEntry.to_string(self, "record", arg)
  364. class DNSAddress(DNSRecord):
  365. """A DNS address record"""
  366. def __init__(self, name, type_, class_, ttl, address):
  367. DNSRecord.__init__(self, name, type_, class_, ttl)
  368. self.address = address
  369. def write(self, out):
  370. """Used in constructing an outgoing packet"""
  371. out.write_string(self.address)
  372. def __eq__(self, other):
  373. """Tests equality on address"""
  374. return (isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and
  375. self.address == other.address)
  376. def __ne__(self, other):
  377. """Non-equality test"""
  378. return not self.__eq__(other)
  379. def __repr__(self):
  380. """String representation"""
  381. try:
  382. return str(socket.inet_ntoa(self.address))
  383. except Exception: # TODO stop catching all Exceptions
  384. return str(self.address)
  385. class DNSHinfo(DNSRecord):
  386. """A DNS host information record"""
  387. def __init__(self, name, type_, class_, ttl, cpu, os):
  388. DNSRecord.__init__(self, name, type_, class_, ttl)
  389. try:
  390. self.cpu = cpu.decode('utf-8')
  391. except AttributeError:
  392. self.cpu = cpu
  393. try:
  394. self.os = os.decode('utf-8')
  395. except AttributeError:
  396. self.os = os
  397. def write(self, out):
  398. """Used in constructing an outgoing packet"""
  399. out.write_character_string(self.cpu.encode('utf-8'))
  400. out.write_character_string(self.os.encode('utf-8'))
  401. def __eq__(self, other):
  402. """Tests equality on cpu and os"""
  403. return (isinstance(other, DNSHinfo) and DNSEntry.__eq__(self, other) and
  404. self.cpu == other.cpu and self.os == other.os)
  405. def __ne__(self, other):
  406. """Non-equality test"""
  407. return not self.__eq__(other)
  408. def __repr__(self):
  409. """String representation"""
  410. return self.cpu + " " + self.os
  411. class DNSPointer(DNSRecord):
  412. """A DNS pointer record"""
  413. def __init__(self, name, type_, class_, ttl, alias):
  414. DNSRecord.__init__(self, name, type_, class_, ttl)
  415. self.alias = alias
  416. def write(self, out):
  417. """Used in constructing an outgoing packet"""
  418. out.write_name(self.alias)
  419. def __eq__(self, other):
  420. """Tests equality on alias"""
  421. return (isinstance(other, DNSPointer) and DNSEntry.__eq__(self, other) and
  422. self.alias == other.alias)
  423. def __ne__(self, other):
  424. """Non-equality test"""
  425. return not self.__eq__(other)
  426. def __repr__(self):
  427. """String representation"""
  428. return self.to_string(self.alias)
  429. class DNSText(DNSRecord):
  430. """A DNS text record"""
  431. def __init__(self, name, type_, class_, ttl, text):
  432. assert isinstance(text, (bytes, type(None)))
  433. DNSRecord.__init__(self, name, type_, class_, ttl)
  434. self.text = text
  435. def write(self, out):
  436. """Used in constructing an outgoing packet"""
  437. out.write_string(self.text)
  438. def __eq__(self, other):
  439. """Tests equality on text"""
  440. return (isinstance(other, DNSText) and DNSEntry.__eq__(self, other) and
  441. self.text == other.text)
  442. def __ne__(self, other):
  443. """Non-equality test"""
  444. return not self.__eq__(other)
  445. def __repr__(self):
  446. """String representation"""
  447. if len(self.text) > 10:
  448. return self.to_string(self.text[:7]) + "..."
  449. else:
  450. return self.to_string(self.text)
  451. class DNSService(DNSRecord):
  452. """A DNS service record"""
  453. def __init__(self, name, type_, class_, ttl,
  454. priority, weight, port, server):
  455. DNSRecord.__init__(self, name, type_, class_, ttl)
  456. self.priority = priority
  457. self.weight = weight
  458. self.port = port
  459. self.server = server
  460. def write(self, out):
  461. """Used in constructing an outgoing packet"""
  462. out.write_short(self.priority)
  463. out.write_short(self.weight)
  464. out.write_short(self.port)
  465. out.write_name(self.server)
  466. def __eq__(self, other):
  467. """Tests equality on priority, weight, port and server"""
  468. return (isinstance(other, DNSService) and
  469. DNSEntry.__eq__(self, other) and
  470. self.priority == other.priority and
  471. self.weight == other.weight and
  472. self.port == other.port and
  473. self.server == other.server)
  474. def __ne__(self, other):
  475. """Non-equality test"""
  476. return not self.__eq__(other)
  477. def __repr__(self):
  478. """String representation"""
  479. return self.to_string("%s:%s" % (self.server, self.port))
  480. class DNSIncoming(QuietLogger):
  481. """Object representation of an incoming DNS packet"""
  482. def __init__(self, data):
  483. """Constructor from string holding bytes of packet"""
  484. self.offset = 0
  485. self.data = data
  486. self.questions = []
  487. self.answers = []
  488. self.id = 0
  489. self.flags = 0
  490. self.num_questions = 0
  491. self.num_answers = 0
  492. self.num_authorities = 0
  493. self.num_additionals = 0
  494. self.valid = False
  495. try:
  496. self.read_header()
  497. self.read_questions()
  498. self.read_others()
  499. self.valid = True
  500. except (IndexError, struct.error, IncomingDecodeError):
  501. self.log_exception_warning((
  502. 'Choked at offset %d while unpacking %r', self.offset, data))
  503. def unpack(self, format_):
  504. length = struct.calcsize(format_)
  505. info = struct.unpack(
  506. format_, self.data[self.offset:self.offset + length])
  507. self.offset += length
  508. return info
  509. def read_header(self):
  510. """Reads header portion of packet"""
  511. (self.id, self.flags, self.num_questions, self.num_answers,
  512. self.num_authorities, self.num_additionals) = self.unpack(b'!6H')
  513. def read_questions(self):
  514. """Reads questions section of packet"""
  515. for i in range(self.num_questions):
  516. name = self.read_name()
  517. type_, class_ = self.unpack(b'!HH')
  518. question = DNSQuestion(name, type_, class_)
  519. self.questions.append(question)
  520. # def read_int(self):
  521. # """Reads an integer from the packet"""
  522. # return self.unpack(b'!I')[0]
  523. def read_character_string(self):
  524. """Reads a character string from the packet"""
  525. length = self.data[self.offset]
  526. self.offset += 1
  527. return self.read_string(length)
  528. def read_string(self, length):
  529. """Reads a string of a given length from the packet"""
  530. info = self.data[self.offset:self.offset + length]
  531. self.offset += length
  532. return info
  533. def read_unsigned_short(self):
  534. """Reads an unsigned short from the packet"""
  535. return self.unpack(b'!H')[0]
  536. def read_others(self):
  537. """Reads the answers, authorities and additionals section of the
  538. packet"""
  539. n = self.num_answers + self.num_authorities + self.num_additionals
  540. for i in range(n):
  541. domain = self.read_name()
  542. type_, class_, ttl, length = self.unpack(b'!HHiH')
  543. rec = None
  544. if type_ == _TYPE_A:
  545. rec = DNSAddress(
  546. domain, type_, class_, ttl, self.read_string(4))
  547. elif type_ == _TYPE_CNAME or type_ == _TYPE_PTR:
  548. rec = DNSPointer(
  549. domain, type_, class_, ttl, self.read_name())
  550. elif type_ == _TYPE_TXT:
  551. rec = DNSText(
  552. domain, type_, class_, ttl, self.read_string(length))
  553. elif type_ == _TYPE_SRV:
  554. rec = DNSService(
  555. domain, type_, class_, ttl,
  556. self.read_unsigned_short(), self.read_unsigned_short(),
  557. self.read_unsigned_short(), self.read_name())
  558. elif type_ == _TYPE_HINFO:
  559. rec = DNSHinfo(
  560. domain, type_, class_, ttl,
  561. self.read_character_string(), self.read_character_string())
  562. elif type_ == _TYPE_AAAA:
  563. rec = DNSAddress(
  564. domain, type_, class_, ttl, self.read_string(16))
  565. else:
  566. # Try to ignore types we don't know about
  567. # Skip the payload for the resource record so the next
  568. # records can be parsed correctly
  569. self.offset += length
  570. if rec is not None:
  571. self.answers.append(rec)
  572. def is_query(self) -> bool:
  573. """Returns true if this is a query"""
  574. return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY
  575. def is_response(self):
  576. """Returns true if this is a response"""
  577. return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE
  578. def read_utf(self, offset, length):
  579. """Reads a UTF-8 string of a given length from the packet"""
  580. return str(self.data[offset:offset + length], 'utf-8', 'replace')
  581. def read_name(self):
  582. """Reads a domain name from the packet"""
  583. result = ''
  584. off = self.offset
  585. next_ = -1
  586. first = off
  587. while True:
  588. length = self.data[off]
  589. off += 1
  590. if length == 0:
  591. break
  592. t = length & 0xC0
  593. if t == 0x00:
  594. result = ''.join((result, self.read_utf(off, length) + '.'))
  595. off += length
  596. elif t == 0xC0:
  597. if next_ < 0:
  598. next_ = off + 1
  599. off = ((length & 0x3F) << 8) | self.data[off]
  600. if off >= first:
  601. raise IncomingDecodeError(
  602. "Bad domain name (circular) at %s" % (off,))
  603. first = off
  604. else:
  605. raise IncomingDecodeError("Bad domain name at %s" % (off,))
  606. if next_ >= 0:
  607. self.offset = next_
  608. else:
  609. self.offset = off
  610. return result
  611. class DNSOutgoing:
  612. """Object representation of an outgoing packet"""
  613. def __init__(self, flags, multicast=True):
  614. self.finished = False
  615. self.id = 0
  616. self.multicast = multicast
  617. self.flags = flags
  618. self.names = {}
  619. self.data = []
  620. self.size = 12
  621. self.state = self.State.init
  622. self.questions = []
  623. self.answers = []
  624. self.authorities = []
  625. self.additionals = []
  626. def __repr__(self):
  627. return '<DNSOutgoing:{%s}>' % ', '.join([
  628. 'multicast=%s' % self.multicast,
  629. 'flags=%s' % self.flags,
  630. 'questions=%s' % self.questions,
  631. 'answers=%s' % self.answers,
  632. 'authorities=%s' % self.authorities,
  633. 'additionals=%s' % self.additionals,
  634. ])
  635. class State(enum.Enum):
  636. init = 0
  637. finished = 1
  638. def add_question(self, record):
  639. """Adds a question"""
  640. self.questions.append(record)
  641. def add_answer(self, inp, record):
  642. """Adds an answer"""
  643. if not record.suppressed_by(inp):
  644. self.add_answer_at_time(record, 0)
  645. def add_answer_at_time(self, record, now):
  646. """Adds an answer if it does not expire by a certain time"""
  647. if record is not None:
  648. if now == 0 or not record.is_expired(now):
  649. self.answers.append((record, now))
  650. def add_authorative_answer(self, record):
  651. """Adds an authoritative answer"""
  652. self.authorities.append(record)
  653. def add_additional_answer(self, record):
  654. """ Adds an additional answer
  655. From: RFC 6763, DNS-Based Service Discovery, February 2013
  656. 12. DNS Additional Record Generation
  657. DNS has an efficiency feature whereby a DNS server may place
  658. additional records in the additional section of the DNS message.
  659. These additional records are records that the client did not
  660. explicitly request, but the server has reasonable grounds to expect
  661. that the client might request them shortly, so including them can
  662. save the client from having to issue additional queries.
  663. This section recommends which additional records SHOULD be generated
  664. to improve network efficiency, for both Unicast and Multicast DNS-SD
  665. responses.
  666. 12.1. PTR Records
  667. When including a DNS-SD Service Instance Enumeration or Selective
  668. Instance Enumeration (subtype) PTR record in a response packet, the
  669. server/responder SHOULD include the following additional records:
  670. o The SRV record(s) named in the PTR rdata.
  671. o The TXT record(s) named in the PTR rdata.
  672. o All address records (type "A" and "AAAA") named in the SRV rdata.
  673. 12.2. SRV Records
  674. When including an SRV record in a response packet, the
  675. server/responder SHOULD include the following additional records:
  676. o All address records (type "A" and "AAAA") named in the SRV rdata.
  677. """
  678. self.additionals.append(record)
  679. def pack(self, format_, value):
  680. self.data.append(struct.pack(format_, value))
  681. self.size += struct.calcsize(format_)
  682. def write_byte(self, value):
  683. """Writes a single byte to the packet"""
  684. self.pack(b'!c', int2byte(value))
  685. def insert_short(self, index, value):
  686. """Inserts an unsigned short in a certain position in the packet"""
  687. self.data.insert(index, struct.pack(b'!H', value))
  688. self.size += 2
  689. def write_short(self, value):
  690. """Writes an unsigned short to the packet"""
  691. self.pack(b'!H', value)
  692. def write_int(self, value):
  693. """Writes an unsigned integer to the packet"""
  694. self.pack(b'!I', int(value))
  695. def write_string(self, value):
  696. """Writes a string to the packet"""
  697. assert isinstance(value, bytes)
  698. self.data.append(value)
  699. self.size += len(value)
  700. def write_utf(self, s):
  701. """Writes a UTF-8 string of a given length to the packet"""
  702. utfstr = s.encode('utf-8')
  703. length = len(utfstr)
  704. if length > 64:
  705. raise NamePartTooLongException
  706. self.write_byte(length)
  707. self.write_string(utfstr)
  708. def write_character_string(self, value):
  709. assert isinstance(value, bytes)
  710. length = len(value)
  711. if length > 256:
  712. raise NamePartTooLongException
  713. self.write_byte(length)
  714. self.write_string(value)
  715. def write_name(self, name):
  716. """
  717. Write names to packet
  718. 18.14. Name Compression
  719. When generating Multicast DNS messages, implementations SHOULD use
  720. name compression wherever possible to compress the names of resource
  721. records, by replacing some or all of the resource record name with a
  722. compact two-byte reference to an appearance of that data somewhere
  723. earlier in the message [RFC1035].
  724. """
  725. # split name into each label
  726. parts = name.split('.')
  727. if not parts[-1]:
  728. parts.pop()
  729. # construct each suffix
  730. name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))]
  731. # look for an existing name or suffix
  732. for count, sub_name in enumerate(name_suffices):
  733. if sub_name in self.names:
  734. break
  735. else:
  736. count = len(name_suffices)
  737. # note the new names we are saving into the packet
  738. name_length = len(name.encode('utf-8'))
  739. for suffix in name_suffices[:count]:
  740. self.names[suffix] = self.size + name_length - len(suffix.encode('utf-8')) - 1
  741. # write the new names out.
  742. for part in parts[:count]:
  743. self.write_utf(part)
  744. # if we wrote part of the name, create a pointer to the rest
  745. if count != len(name_suffices):
  746. # Found substring in packet, create pointer
  747. index = self.names[name_suffices[count]]
  748. self.write_byte((index >> 8) | 0xC0)
  749. self.write_byte(index & 0xFF)
  750. else:
  751. # this is the end of a name
  752. self.write_byte(0)
  753. def write_question(self, question):
  754. """Writes a question to the packet"""
  755. self.write_name(question.name)
  756. self.write_short(question.type)
  757. self.write_short(question.class_)
  758. def write_record(self, record, now):
  759. """Writes a record (answer, authoritative answer, additional) to
  760. the packet"""
  761. if self.state == self.State.finished:
  762. return 1
  763. start_data_length, start_size = len(self.data), self.size
  764. self.write_name(record.name)
  765. self.write_short(record.type)
  766. if record.unique and self.multicast:
  767. self.write_short(record.class_ | _CLASS_UNIQUE)
  768. else:
  769. self.write_short(record.class_)
  770. if now == 0:
  771. self.write_int(record.ttl)
  772. else:
  773. self.write_int(record.get_remaining_ttl(now))
  774. index = len(self.data)
  775. # Adjust size for the short we will write before this record
  776. self.size += 2
  777. record.write(self)
  778. self.size -= 2
  779. length = sum((len(d) for d in self.data[index:]))
  780. # Here is the short we adjusted for
  781. self.insert_short(index, length)
  782. # if we go over, then rollback and quit
  783. if self.size > _MAX_MSG_ABSOLUTE:
  784. while len(self.data) > start_data_length:
  785. self.data.pop()
  786. self.size = start_size
  787. self.state = self.State.finished
  788. return 1
  789. return 0
  790. def packet(self) -> bytes:
  791. """Returns a string containing the packet's bytes
  792. No further parts should be added to the packet once this
  793. is done."""
  794. overrun_answers, overrun_authorities, overrun_additionals = 0, 0, 0
  795. if self.state != self.State.finished:
  796. for question in self.questions:
  797. self.write_question(question)
  798. for answer, time_ in self.answers:
  799. overrun_answers += self.write_record(answer, time_)
  800. for authority in self.authorities:
  801. overrun_authorities += self.write_record(authority, 0)
  802. for additional in self.additionals:
  803. overrun_additionals += self.write_record(additional, 0)
  804. self.state = self.State.finished
  805. self.insert_short(0, len(self.additionals) - overrun_additionals)
  806. self.insert_short(0, len(self.authorities) - overrun_authorities)
  807. self.insert_short(0, len(self.answers) - overrun_answers)
  808. self.insert_short(0, len(self.questions))
  809. self.insert_short(0, self.flags)
  810. if self.multicast:
  811. self.insert_short(0, 0)
  812. else:
  813. self.insert_short(0, self.id)
  814. return b''.join(self.data)
  815. class DNSCache:
  816. """A cache of DNS entries"""
  817. def __init__(self):
  818. self.cache = {}
  819. def add(self, entry):
  820. """Adds an entry"""
  821. # Insert first in list so get returns newest entry
  822. self.cache.setdefault(entry.key, []).insert(0, entry)
  823. def remove(self, entry):
  824. """Removes an entry"""
  825. try:
  826. list_ = self.cache[entry.key]
  827. list_.remove(entry)
  828. except (KeyError, ValueError):
  829. pass
  830. def get(self, entry):
  831. """Gets an entry by key. Will return None if there is no
  832. matching entry."""
  833. try:
  834. list_ = self.cache[entry.key]
  835. for cached_entry in list_:
  836. if entry.__eq__(cached_entry):
  837. return cached_entry
  838. except (KeyError, ValueError):
  839. return None
  840. def get_by_details(self, name, type_, class_):
  841. """Gets an entry by details. Will return None if there is
  842. no matching entry."""
  843. entry = DNSEntry(name, type_, class_)
  844. return self.get(entry)
  845. def entries_with_name(self, name):
  846. """Returns a list of entries whose key matches the name."""
  847. try:
  848. return self.cache[name.lower()]
  849. except KeyError:
  850. return []
  851. def current_entry_with_name_and_alias(self, name, alias):
  852. now = current_time_millis()
  853. for record in self.entries_with_name(name):
  854. if (record.type == _TYPE_PTR and
  855. not record.is_expired(now) and
  856. record.alias == alias):
  857. return record
  858. def entries(self):
  859. """Returns a list of all entries"""
  860. if not self.cache:
  861. return []
  862. else:
  863. # avoid size change during iteration by copying the cache
  864. values = list(self.cache.values())
  865. return reduce(lambda a, b: a + b, values)
  866. class Engine(threading.Thread):
  867. """An engine wraps read access to sockets, allowing objects that
  868. need to receive data from sockets to be called back when the
  869. sockets are ready.
  870. A reader needs a handle_read() method, which is called when the socket
  871. it is interested in is ready for reading.
  872. Writers are not implemented here, because we only send short
  873. packets.
  874. """
  875. def __init__(self, zc):
  876. threading.Thread.__init__(self, name='zeroconf-Engine')
  877. self.daemon = True
  878. self.zc = zc
  879. self.readers = {} # maps socket to reader
  880. self.timeout = 5
  881. self.condition = threading.Condition()
  882. self.start()
  883. def run(self):
  884. while not self.zc.done:
  885. with self.condition:
  886. rs = self.readers.keys()
  887. if len(rs) == 0:
  888. # No sockets to manage, but we wait for the timeout
  889. # or addition of a socket
  890. self.condition.wait(self.timeout)
  891. if len(rs) != 0:
  892. try:
  893. rr, wr, er = select.select(rs, [], [], self.timeout)
  894. if not self.zc.done:
  895. for socket_ in rr:
  896. reader = self.readers.get(socket_)
  897. if reader:
  898. reader.handle_read(socket_)
  899. except (select.error, socket.error) as e:
  900. # If the socket was closed by another thread, during
  901. # shutdown, ignore it and exit
  902. if e.args[0] != socket.EBADF or not self.zc.done:
  903. raise
  904. def add_reader(self, reader, socket_):
  905. with self.condition:
  906. self.readers[socket_] = reader
  907. self.condition.notify()
  908. def del_reader(self, socket_):
  909. with self.condition:
  910. del self.readers[socket_]
  911. self.condition.notify()
  912. class Listener(QuietLogger):
  913. """A Listener is used by this module to listen on the multicast
  914. group to which DNS messages are sent, allowing the implementation
  915. to cache information as it arrives.
  916. It requires registration with an Engine object in order to have
  917. the read() method called when a socket is available for reading."""
  918. def __init__(self, zc):
  919. self.zc = zc
  920. self.data = None
  921. def handle_read(self, socket_):
  922. try:
  923. data, (addr, port) = socket_.recvfrom(_MAX_MSG_ABSOLUTE)
  924. except Exception:
  925. self.log_exception_warning()
  926. return
  927. log.debug('Received from %r:%r: %r ', addr, port, data)
  928. self.data = data
  929. msg = DNSIncoming(data)
  930. if not msg.valid:
  931. pass
  932. elif msg.is_query():
  933. # Always multicast responses
  934. if port == _MDNS_PORT:
  935. self.zc.handle_query(msg, _MDNS_ADDR, _MDNS_PORT)
  936. # If it's not a multicast query, reply via unicast
  937. # and multicast
  938. elif port == _DNS_PORT:
  939. self.zc.handle_query(msg, addr, port)
  940. self.zc.handle_query(msg, _MDNS_ADDR, _MDNS_PORT)
  941. else:
  942. self.zc.handle_response(msg)
  943. class Reaper(threading.Thread):
  944. """A Reaper is used by this module to remove cache entries that
  945. have expired."""
  946. def __init__(self, zc):
  947. threading.Thread.__init__(self, name='zeroconf-Reaper')
  948. self.daemon = True
  949. self.zc = zc
  950. self.start()
  951. def run(self):
  952. while True:
  953. self.zc.wait(10 * 1000)
  954. if self.zc.done:
  955. return
  956. now = current_time_millis()
  957. for record in self.zc.cache.entries():
  958. if record.is_expired(now):
  959. self.zc.update_record(now, record)
  960. self.zc.cache.remove(record)
  961. class Signal:
  962. def __init__(self):
  963. self._handlers = []
  964. def fire(self, **kwargs):
  965. for h in list(self._handlers):
  966. h(**kwargs)
  967. @property
  968. def registration_interface(self):
  969. return SignalRegistrationInterface(self._handlers)
  970. class SignalRegistrationInterface:
  971. def __init__(self, handlers):
  972. self._handlers = handlers
  973. def register_handler(self, handler):
  974. self._handlers.append(handler)
  975. return self
  976. def unregister_handler(self, handler):
  977. self._handlers.remove(handler)
  978. return self
  979. class RecordUpdateListener:
  980. def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
  981. raise NotImplementedError()
  982. class ServiceBrowser(RecordUpdateListener, threading.Thread):
  983. """Used to browse for a service of a specific type.
  984. The listener object will have its add_service() and
  985. remove_service() methods called when this browser
  986. discovers changes in the services availability."""
  987. def __init__(self, zc: 'Zeroconf', type_: str, handlers=None, listener=None,
  988. addr: str = _MDNS_ADDR, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME) -> None:
  989. """Creates a browser for a specific type"""
  990. assert handlers or listener, 'You need to specify at least one handler'
  991. if not type_.endswith(service_type_name(type_, allow_underscores=True)):
  992. raise BadTypeInNameException
  993. threading.Thread.__init__(
  994. self, name='zeroconf-ServiceBrowser_' + type_)
  995. self.daemon = True
  996. self.zc = zc
  997. self.type = type_
  998. self.addr = addr
  999. self.port = port
  1000. self.multicast = (self.addr == _MDNS_ADDR)
  1001. self.services = {} # type: Dict[str, DNSRecord]
  1002. self.next_time = current_time_millis()
  1003. self.delay = delay
  1004. self._handlers_to_call = [] # type: List[Callable[[Zeroconf], None]]
  1005. self._service_state_changed = Signal()
  1006. self.done = False
  1007. if hasattr(handlers, 'add_service'):
  1008. listener = handlers
  1009. handlers = None
  1010. handlers = handlers or []
  1011. if listener:
  1012. def on_change(zeroconf, service_type, name, state_change):
  1013. args = (zeroconf, service_type, name)
  1014. if state_change is ServiceStateChange.Added:
  1015. listener.add_service(*args)
  1016. elif state_change is ServiceStateChange.Removed:
  1017. listener.remove_service(*args)
  1018. else:
  1019. raise NotImplementedError(state_change)
  1020. handlers.append(on_change)
  1021. for h in handlers:
  1022. self.service_state_changed.register_handler(h)
  1023. self.start()
  1024. @property
  1025. def service_state_changed(self) -> SignalRegistrationInterface:
  1026. return self._service_state_changed.registration_interface
  1027. def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
  1028. """Callback invoked by Zeroconf when new information arrives.
  1029. Updates information required by browser in the Zeroconf cache."""
  1030. def enqueue_callback(state_change: ServiceStateChange, name: str) -> None:
  1031. self._handlers_to_call.append(
  1032. lambda zeroconf: self._service_state_changed.fire(
  1033. zeroconf=zeroconf,
  1034. service_type=self.type,
  1035. name=name,
  1036. state_change=state_change,
  1037. ))
  1038. if record.type == _TYPE_PTR and record.name == self.type:
  1039. assert isinstance(record, DNSPointer)
  1040. expired = record.is_expired(now)
  1041. service_key = record.alias.lower()
  1042. try:
  1043. old_record = self.services[service_key]
  1044. except KeyError:
  1045. if not expired:
  1046. self.services[service_key] = record
  1047. enqueue_callback(ServiceStateChange.Added, record.alias)
  1048. else:
  1049. if not expired:
  1050. old_record.reset_ttl(record)
  1051. else:
  1052. del self.services[service_key]
  1053. enqueue_callback(ServiceStateChange.Removed, record.alias)
  1054. return
  1055. expires = record.get_expiration_time(75)
  1056. if expires < self.next_time:
  1057. self.next_time = expires
  1058. def cancel(self):
  1059. self.done = True
  1060. self.zc.remove_listener(self)
  1061. self.join()
  1062. def run(self):
  1063. self.zc.add_listener(self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
  1064. while True:
  1065. now = current_time_millis()
  1066. if len(self._handlers_to_call) == 0 and self.next_time > now:
  1067. self.zc.wait(self.next_time - now)
  1068. if self.zc.done or self.done:
  1069. return
  1070. now = current_time_millis()
  1071. if self.next_time <= now:
  1072. out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast)
  1073. out.add_question(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
  1074. for record in self.services.values():
  1075. if not record.is_stale(now):
  1076. out.add_answer_at_time(record, now)
  1077. self.zc.send(out, addr=self.addr, port=self.port)
  1078. self.next_time = now + self.delay
  1079. self.delay = min(20 * 1000, self.delay * 2)
  1080. if len(self._handlers_to_call) > 0 and not self.zc.done:
  1081. handler = self._handlers_to_call.pop(0)
  1082. handler(self.zc)
# Type of ServiceInfo.properties (TXT-record key -> value).
# NOTE(review): ServiceInfo._set_text() stores bytes keys with bytes or
# True/False values, while _set_properties() keeps the caller's dict as-is,
# so the declared Union[bool, str] value type looks incomplete — confirm.
ServicePropertiesType = Dict[bytes, Union[bool, str]]
  1084. class ServiceInfo(RecordUpdateListener):
  1085. """Service information"""
  1086. def __init__(self, type_: str, name: str, address: bytes = None, port: int = None, weight: int = 0,
  1087. priority: int = 0, properties=None, server: str = None) -> None:
  1088. """Create a service description.
  1089. type_: fully qualified service type name
  1090. name: fully qualified service name
  1091. address: IP address as unsigned short, network byte order
  1092. port: port that the service runs on
  1093. weight: weight of the service
  1094. priority: priority of the service
  1095. properties: dictionary of properties (or a string holding the
  1096. bytes for the text field)
  1097. server: fully qualified name for service host (defaults to name)"""
  1098. if not type_.endswith(service_type_name(name, allow_underscores=True)):
  1099. raise BadTypeInNameException
  1100. self.type = type_
  1101. self.name = name
  1102. self.address = address
  1103. self.port = port
  1104. self.weight = weight
  1105. self.priority = priority
  1106. if server:
  1107. self.server = server
  1108. else:
  1109. self.server = name
  1110. self._properties = {} # type: ServicePropertiesType
  1111. self._set_properties(properties)
  1112. # FIXME: this is here only so that mypy doesn't complain when we set and then use the attribute when
  1113. # registering services. See if setting this to None by default is the right way to go.
  1114. self.ttl = None # type: Optional[int]
  1115. @property
  1116. def properties(self) -> ServicePropertiesType:
  1117. return self._properties
  1118. def _set_properties(self, properties: Union[bytes, ServicePropertiesType]):
  1119. """Sets properties and text of this info from a dictionary"""
  1120. if isinstance(properties, dict):
  1121. self._properties = properties
  1122. list_ = []
  1123. result = b''
  1124. for key, value in properties.items():
  1125. if isinstance(key, str):
  1126. key = key.encode('utf-8')
  1127. if value is None:
  1128. suffix = b''
  1129. elif isinstance(value, str):
  1130. suffix = value.encode('utf-8')
  1131. elif isinstance(value, bytes):
  1132. suffix = value
  1133. elif isinstance(value, int):
  1134. if value:
  1135. suffix = b'true'
  1136. else:
  1137. suffix = b'false'
  1138. else:
  1139. suffix = b''
  1140. list_.append(b'='.join((key, suffix)))
  1141. for item in list_:
  1142. result = b''.join((result, int2byte(len(item)), item))
  1143. self.text = result
  1144. else:
  1145. self.text = properties
  1146. def _set_text(self, text):
  1147. """Sets properties and text given a text field"""
  1148. self.text = text
  1149. result = {}
  1150. end = len(text)
  1151. index = 0
  1152. strs = []
  1153. while index < end:
  1154. length = text[index]
  1155. index += 1
  1156. strs.append(text[index:index + length])
  1157. index += length
  1158. for s in strs:
  1159. parts = s.split(b'=', 1)
  1160. try:
  1161. key, value = parts
  1162. except ValueError:
  1163. # No equals sign at all
  1164. key = s
  1165. value = False
  1166. else:
  1167. if value == b'true':
  1168. value = True
  1169. elif value == b'false' or not value:
  1170. value = False
  1171. # Only update non-existent properties
  1172. if key and result.get(key) is None:
  1173. result[key] = value
  1174. self._properties = result
  1175. def get_name(self):
  1176. """Name accessor"""
  1177. if self.type is not None and self.name.endswith("." + self.type):
  1178. return self.name[:len(self.name) - len(self.type) - 1]
  1179. return self.name
  1180. def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
  1181. """Updates service information from a DNS record"""
  1182. if record is not None and not record.is_expired(now):
  1183. if record.type == _TYPE_A:
  1184. assert isinstance(record, DNSAddress)
  1185. # if record.name == self.name:
  1186. if record.name == self.server:
  1187. self.address = record.address
  1188. elif record.type == _TYPE_SRV:
  1189. assert isinstance(record, DNSService)
  1190. if record.name == self.name:
  1191. self.server = record.server
  1192. self.port = record.port
  1193. self.weight = record.weight
  1194. self.priority = record.priority
  1195. # self.address = None
  1196. self.update_record(
  1197. zc, now, zc.cache.get_by_details(
  1198. self.server, _TYPE_A, _CLASS_IN))
  1199. elif record.type == _TYPE_TXT:
  1200. assert isinstance(record, DNSText)
  1201. if record.name == self.name:
  1202. self._set_text(record.text)
  1203. def request(self, zc: 'Zeroconf', timeout: float) -> bool:
  1204. """Returns true if the service could be discovered on the
  1205. network, and updates this object with details discovered.
  1206. """
  1207. now = current_time_millis()
  1208. delay = _LISTENER_TIME
  1209. next_ = now + delay
  1210. last = now + timeout
  1211. record_types_for_check_cache = [
  1212. (_TYPE_SRV, _CLASS_IN),
  1213. (_TYPE_TXT, _CLASS_IN),
  1214. ]
  1215. if self.server is not None:
  1216. record_types_for_check_cache.append((_TYPE_A, _CLASS_IN))
  1217. for record_type in record_types_for_check_cache:
  1218. cached = zc.cache.get_by_details(self.name, *record_type)
  1219. if cached:
  1220. self.update_record(zc, now, cached)
  1221. if None not in (self.server, self.address, self.text):
  1222. return True
  1223. try:
  1224. zc.add_listener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN))
  1225. while None in (self.server, self.address, self.text):
  1226. if last <= now:
  1227. return False
  1228. if next_ <= now:
  1229. out = DNSOutgoing(_FLAGS_QR_QUERY)
  1230. out.add_question(
  1231. DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN))
  1232. out.add_answer_at_time(
  1233. zc.cache.get_by_details(
  1234. self.name, _TYPE_SRV, _CLASS_IN), now)
  1235. out.add_question(
  1236. DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN))
  1237. out.add_answer_at_time(
  1238. zc.cache.get_by_details(
  1239. self.name, _TYPE_TXT, _CLASS_IN), now)
  1240. if self.server is not None:
  1241. out.add_question(
  1242. DNSQuestion(self.server, _TYPE_A, _CLASS_IN))
  1243. out.add_answer_at_time(
  1244. zc.cache.get_by_details(
  1245. self.server, _TYPE_A, _CLASS_IN), now)
  1246. zc.send(out)
  1247. next_ = now + delay
  1248. delay *= 2
  1249. zc.wait(min(next_, last) - now)
  1250. now = current_time_millis()
  1251. finally:
  1252. zc.remove_listener(self)
  1253. return True
  1254. def __eq__(self, other: object) -> bool:
  1255. """Tests equality of service name"""
  1256. return isinstance(other, ServiceInfo) and other.name == self.name
  1257. def __ne__(self, other: object) -> bool:
  1258. """Non-equality test"""
  1259. return not self.__eq__(other)
  1260. def __repr__(self) -> str:
  1261. """String representation"""
  1262. return '%s(%s)' % (
  1263. type(self).__name__,
  1264. ', '.join(
  1265. '%s=%r' % (name, getattr(self, name))
  1266. for name in (
  1267. 'type', 'name', 'address', 'port', 'weight', 'priority',
  1268. 'server', 'properties',
  1269. )
  1270. )
  1271. )
  1272. class ZeroconfServiceTypes:
  1273. """
  1274. Return all of the advertised services on any local networks
  1275. """
  1276. def __init__(self):
  1277. self.found_services = set()
  1278. def add_service(self, zc, type_, name):
  1279. self.found_services.add(name)
  1280. def remove_service(self, zc, type_, name):
  1281. pass
  1282. @classmethod
  1283. def find(cls, zc=None, timeout=5, interfaces=InterfaceChoice.All):
  1284. """
  1285. Return all of the advertised services on any local networks.
  1286. :param zc: Zeroconf() instance. Pass in if already have an
  1287. instance running or if non-default interfaces are needed
  1288. :param timeout: seconds to wait for any responses
  1289. :return: tuple of service type strings
  1290. """
  1291. local_zc = zc or Zeroconf(interfaces=interfaces)
  1292. listener = cls()
  1293. browser = ServiceBrowser(
  1294. local_zc, '_services._dns-sd._udp.local.', listener=listener)
  1295. # wait for responses
  1296. time.sleep(timeout)
  1297. # close down anything we opened
  1298. if zc is None:
  1299. local_zc.close()
  1300. else:
  1301. browser.cancel()
  1302. return tuple(sorted(listener.found_services))
  1303. def get_all_addresses() -> List[str]:
  1304. return list(set(
  1305. addr.ip
  1306. for iface in ifaddr.get_adapters()
  1307. for addr in iface.ips
  1308. if addr.is_IPv4 and addr.network_prefix != 32 # Host only netmask 255.255.255.255
  1309. ))
  1310. def normalize_interface_choice(choice: Union[List[str], InterfaceChoice]) -> List[str]:
  1311. if choice is InterfaceChoice.Default:
  1312. return ['0.0.0.0']
  1313. elif choice is InterfaceChoice.All:
  1314. return get_all_addresses()
  1315. else:
  1316. assert isinstance(choice, list)
  1317. return choice
  1318. def new_socket(port: int = _MDNS_PORT) -> socket.socket:
  1319. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  1320. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  1321. # SO_REUSEADDR should be equivalent to SO_REUSEPORT for
  1322. # multicast UDP sockets (p 731, "TCP/IP Illustrated,
  1323. # Volume 2"), but some BSD-derived systems require
  1324. # SO_REUSEPORT to be specified explicity. Also, not all
  1325. # versions of Python have SO_REUSEPORT available.
  1326. # Catch OSError and socket.error for kernel versions <3.9 because lacking
  1327. # SO_REUSEPORT support.
  1328. try:
  1329. reuseport = socket.SO_REUSEPORT
  1330. except AttributeError:
  1331. pass
  1332. else:
  1333. try:
  1334. s.setsockopt(socket.SOL_SOCKET, reuseport, 1)
  1335. except (OSError, socket.error) as err:
  1336. # OSError on python 3, socket.error on python 2
  1337. if not err.errno == errno.ENOPROTOOPT:
  1338. raise
  1339. if port is _MDNS_PORT:
  1340. # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and
  1341. # IP_MULTICAST_LOOP socket options as an unsigned char.
  1342. ttl = struct.pack(b'B', 255)
  1343. s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl)
  1344. loop = struct.pack(b'B', 1)
  1345. s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop)
  1346. s.bind(('', port))
  1347. return s
  1348. def get_errno(e: Exception) -> int:
  1349. assert isinstance(e, socket.error)
  1350. return e.args[0]
  1351. class Zeroconf(QuietLogger):
  1352. """Implementation of Zeroconf Multicast DNS Service Discovery
  1353. Supports registration, unregistration, queries and browsing.
  1354. """
    def __init__(
        self,
        interfaces: Union[List[str], InterfaceChoice] = InterfaceChoice.All,
        unicast: bool = False
    ) -> None:
        """Creates an instance of the Zeroconf class, establishing
        multicast communications, listening and reaping threads.

        :type interfaces: :class:`InterfaceChoice` or sequence of ip addresses
        :param unicast: when False (default), one shared listen socket on the
            mDNS port joins the multicast group on every interface, and a
            separate respond socket is created per interface.  When True, no
            listen socket is created; instead one socket bound to an ephemeral
            port (``new_socket(port=0)``) is created per interface, and the
            engine reads from those directly.
        """
        # hook for threads
        self._GLOBAL_DONE = False
        self.unicast = unicast

        if not unicast:
            self._listen_socket = new_socket()
        interfaces = normalize_interface_choice(interfaces)

        self._respond_sockets = []  # type: List[socket.socket]

        for i in interfaces:
            if not unicast:
                log.debug('Adding %r to multicast group', i)
                try:
                    # ip_mreq: multicast group address followed by the local interface address.
                    _value = socket.inet_aton(_MDNS_ADDR) + socket.inet_aton(i)
                    self._listen_socket.setsockopt(
                        socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value)
                except socket.error as e:
                    _errno = get_errno(e)
                    if _errno == errno.EADDRINUSE:
                        # No `continue` here: a respond socket is still created
                        # for this interface.
                        log.info(
                            'Address in use when adding %s to multicast group, '
                            'it is expected to happen on some systems', i,
                        )
                    elif _errno == errno.EADDRNOTAVAIL:
                        # Skip this interface entirely.
                        log.info(
                            'Address not available when adding %s to multicast '
                            'group, it is expected to happen on some systems', i,
                        )
                        continue
                    elif _errno == errno.EINVAL:
                        # Skip this interface entirely.
                        log.info(
                            'Interface of %s does not support multicast, '
                            'it is expected in WSL', i
                        )
                        continue
                    else:
                        # NOTE(review): sockets created earlier in this
                        # constructor are not closed when this propagates.
                        raise

                # Outgoing multicast from this socket leaves via interface i.
                respond_socket = new_socket()
                respond_socket.setsockopt(
                    socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(i))
            else:
                respond_socket = new_socket(port=0)

            self._respond_sockets.append(respond_socket)

        self.listeners = []  # type: List[RecordUpdateListener]
        self.browsers = {}  # type: Dict[RecordUpdateListener, ServiceBrowser]
        self.services = {}  # type: Dict[str, ServiceInfo]
        self.servicetypes = {}  # type: Dict[str, int]

        self.cache = DNSCache()

        # Shared by wait()/notify_all() to wake browser, reaper and request loops.
        self.condition = threading.Condition()

        # Engine and Reaper start their threads in their constructors.
        self.engine = Engine(self)
        self.listener = Listener(self)
        if not unicast:
            self.engine.add_reader(self.listener, self._listen_socket)
        else:
            for s in self._respond_sockets:
                self.engine.add_reader(self.listener, s)
        self.reaper = Reaper(self)

        # Last probe sent by check_service(), presumably kept for debugging/tests.
        self.debug = None  # type: Optional[DNSOutgoing]
  1420. @property
  1421. def done(self) -> bool:
  1422. return self._GLOBAL_DONE
  1423. def wait(self, timeout: float) -> None:
  1424. """Calling thread waits for a given number of milliseconds or
  1425. until notified."""
  1426. with self.condition:
  1427. self.condition.wait(timeout / 1000.0)
  1428. def notify_all(self) -> None:
  1429. """Notifies all waiting threads"""
  1430. with self.condition:
  1431. self.condition.notify_all()
  1432. def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]:
  1433. """Returns network's service information for a particular
  1434. name and type, or None if no service matches by the timeout,
  1435. which defaults to 3 seconds."""
  1436. info = ServiceInfo(type_, name)
  1437. if info.request(self, timeout):
  1438. return info
  1439. return None
  1440. def add_service_listener(self, type_: str, listener: RecordUpdateListener) -> None:
  1441. """Adds a listener for a particular service type. This object
  1442. will then have its update_record method called when information
  1443. arrives for that type."""
  1444. self.remove_service_listener(listener)
  1445. self.browsers[listener] = ServiceBrowser(self, type_, listener)
  1446. def remove_service_listener(self, listener: RecordUpdateListener) -> None:
  1447. """Removes a listener from the set that is currently listening."""
  1448. if listener in self.browsers:
  1449. self.browsers[listener].cancel()
  1450. del self.browsers[listener]
  1451. def remove_all_service_listeners(self) -> None:
  1452. """Removes a listener from the set that is currently listening."""
  1453. for listener in [k for k in self.browsers]:
  1454. self.remove_service_listener(listener)
  1455. def register_service(
  1456. self, info: ServiceInfo, ttl: int = _DNS_TTL, allow_name_change: bool = False,
  1457. ) -> None:
  1458. """Registers service information to the network with a default TTL
  1459. of 60 seconds. Zeroconf will then respond to requests for
  1460. information for that service. The name of the service may be
  1461. changed if needed to make it unique on the network."""
  1462. info.ttl = ttl
  1463. self.check_service(info, allow_name_change)
  1464. self.services[info.name.lower()] = info
  1465. if info.type in self.servicetypes:
  1466. self.servicetypes[info.type] += 1
  1467. else:
  1468. self.servicetypes[info.type] = 1
  1469. now = current_time_millis()
  1470. next_time = now
  1471. i = 0
  1472. while i < 3:
  1473. if now < next_time:
  1474. self.wait(next_time - now)
  1475. now = current_time_millis()
  1476. continue
  1477. out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
  1478. out.add_answer_at_time(
  1479. DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, ttl, info.name), 0)
  1480. out.add_answer_at_time(
  1481. DNSService(info.name, _TYPE_SRV, _CLASS_IN,
  1482. ttl, info.priority, info.weight, info.port,
  1483. info.server), 0)
  1484. out.add_answer_at_time(
  1485. DNSText(info.name, _TYPE_TXT, _CLASS_IN, ttl, info.text), 0)
  1486. if info.address:
  1487. out.add_answer_at_time(
  1488. DNSAddress(info.server, _TYPE_A, _CLASS_IN,
  1489. ttl, info.address), 0)
  1490. self.send(out)
  1491. i += 1
  1492. next_time += _REGISTER_TIME
  1493. def unregister_service(self, info: ServiceInfo) -> None:
  1494. """Unregister a service."""
  1495. try:
  1496. del self.services[info.name.lower()]
  1497. if self.servicetypes[info.type] > 1:
  1498. self.servicetypes[info.type] -= 1
  1499. else:
  1500. del self.servicetypes[info.type]
  1501. except Exception as e: # TODO stop catching all Exceptions
  1502. log.exception('Unknown error, possibly benign: %r', e)
  1503. now = current_time_millis()
  1504. next_time = now
  1505. i = 0
  1506. while i < 3:
  1507. if now < next_time:
  1508. self.wait(next_time - now)
  1509. now = current_time_millis()
  1510. continue
  1511. out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
  1512. out.add_answer_at_time(
  1513. DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
  1514. out.add_answer_at_time(
  1515. DNSService(info.name, _TYPE_SRV, _CLASS_IN, 0,
  1516. info.priority, info.weight, info.port, info.name), 0)
  1517. out.add_answer_at_time(
  1518. DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
  1519. if info.address:
  1520. out.add_answer_at_time(
  1521. DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0,
  1522. info.address), 0)
  1523. self.send(out)
  1524. i += 1
  1525. next_time += _UNREGISTER_TIME
  1526. def unregister_all_services(self) -> None:
  1527. """Unregister all registered services."""
  1528. if len(self.services) > 0:
  1529. now = current_time_millis()
  1530. next_time = now
  1531. i = 0
  1532. while i < 3:
  1533. if now < next_time:
  1534. self.wait(next_time - now)
  1535. now = current_time_millis()
  1536. continue
  1537. out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
  1538. for info in self.services.values():
  1539. out.add_answer_at_time(DNSPointer(
  1540. info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
  1541. out.add_answer_at_time(DNSService(
  1542. info.name, _TYPE_SRV, _CLASS_IN, 0,
  1543. info.priority, info.weight, info.port, info.server), 0)
  1544. out.add_answer_at_time(DNSText(
  1545. info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
  1546. if info.address:
  1547. out.add_answer_at_time(DNSAddress(
  1548. info.server, _TYPE_A, _CLASS_IN, 0,
  1549. info.address), 0)
  1550. self.send(out)
  1551. i += 1
  1552. next_time += _UNREGISTER_TIME
    def check_service(self, info: ServiceInfo, allow_name_change: bool) -> None:
        """Checks the network for a unique service name, modifying the
        ServiceInfo passed in if it is not unique.

        Sends three probe queries for ``info.type``, spaced _CHECK_TIME apart.
        Before each probe the cache is checked for an unexpired PTR record of
        ``info.type`` pointing at ``info.name``.  On such a conflict either
        NonUniqueNameException is raised (``allow_name_change`` false) or
        the name is changed to ``<instance>-N.<type>`` (N = 2, 3, ...) and
        probing restarts from the first probe.

        :raises BadTypeInNameException: if ``info.name`` does not end with
            a service type that ``info.type`` ends with.
        """

        # This is kind of funky because of the subtype based tests
        # need to make subtypes a first class citizen
        service_name = service_type_name(info.name)
        if not info.type.endswith(service_name):
            raise BadTypeInNameException

        # The instance part of the name, without the trailing ".<type>".
        instance_name = info.name[:-len(service_name) - 1]
        next_instance_number = 2

        now = current_time_millis()
        next_time = now
        i = 0
        while i < 3:
            # check for a name conflict
            while self.cache.current_entry_with_name_and_alias(
                    info.type, info.name):
                if not allow_name_change:
                    raise NonUniqueNameException

                # change the name and look for a conflict
                info.name = '%s-%s.%s' % (
                    instance_name, next_instance_number, info.type)
                next_instance_number += 1
                # Result unused; presumably called so an invalid new name raises — confirm.
                service_type_name(info.name)
                # Restart the probe sequence for the new name.
                next_time = now
                i = 0

            if now < next_time:
                self.wait(next_time - now)
                now = current_time_millis()
                continue

            out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
            self.debug = out
            out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN))
            # info.ttl is set by register_service() before it calls this method.
            out.add_authorative_answer(DNSPointer(
                info.type, _TYPE_PTR, _CLASS_IN, info.ttl, info.name))
            self.send(out)
            i += 1
            next_time += _CHECK_TIME
  1591. def add_listener(self, listener: RecordUpdateListener, question: Optional[DNSQuestion]) -> None:
  1592. """Adds a listener for a given question. The listener will have
  1593. its update_record method called when information is available to
  1594. answer the question."""
  1595. now = current_time_millis()
  1596. self.listeners.append(listener)
  1597. if question is not None:
  1598. for record in self.cache.entries_with_name(question.name):
  1599. if question.answered_by(record) and not record.is_expired(now):
  1600. listener.update_record(self, now, record)
  1601. self.notify_all()
  1602. def remove_listener(self, listener: RecordUpdateListener) -> None:
  1603. """Removes a listener."""
  1604. try:
  1605. self.listeners.remove(listener)
  1606. self.notify_all()
  1607. except Exception as e: # TODO stop catching all Exceptions
  1608. log.exception('Unknown error, possibly benign: %r', e)
  1609. def update_record(self, now: float, rec: DNSRecord) -> None:
  1610. """Used to notify listeners of new information that has updated
  1611. a record."""
  1612. for listener in self.listeners:
  1613. listener.update_record(self, now, rec)
  1614. self.notify_all()
  1615. def handle_response(self, msg: DNSIncoming) -> None:
  1616. """Deal with incoming response packets. All answers
  1617. are held in the cache, and listeners are notified."""
  1618. now = current_time_millis()
  1619. for record in msg.answers:
  1620. expired = record.is_expired(now)
  1621. if record in self.cache.entries():
  1622. if expired:
  1623. self.cache.remove(record)
  1624. else:
  1625. entry = self.cache.get(record)
  1626. if entry is not None:
  1627. entry.reset_ttl(record)
  1628. else:
  1629. self.cache.add(record)
  1630. for record in msg.answers:
  1631. self.update_record(now, record)
  1632. def handle_query(self, msg: DNSIncoming, addr: str, port: int) -> None:
  1633. """Deal with incoming query packets. Provides a response if
  1634. possible."""
  1635. out = None
  1636. # Support unicast client responses
  1637. #
  1638. if port != _MDNS_PORT:
  1639. out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False)
  1640. for question in msg.questions:
  1641. out.add_question(question)
  1642. for question in msg.questions:
  1643. if question.type == _TYPE_PTR:
  1644. if question.name == "_services._dns-sd._udp.local.":
  1645. for stype in self.servicetypes.keys():
  1646. if out is None:
  1647. out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
  1648. out.add_answer(msg, DNSPointer(
  1649. "_services._dns-sd._udp.local.", _TYPE_PTR,
  1650. _CLASS_IN, _DNS_TTL, stype))
  1651. for service in self.services.values():
  1652. if question.name == service.type:
  1653. if out is None:
  1654. out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
  1655. out.add_answer(msg, DNSPointer(
  1656. service.type, _TYPE_PTR,
  1657. _CLASS_IN, service.ttl, service.name))
  1658. else:
  1659. try:
  1660. if out is None:
  1661. out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
  1662. # Answer A record queries for any service addresses we know
  1663. if question.type in (_TYPE_A, _TYPE_ANY):
  1664. for service in self.services.values():
  1665. if service.server == question.name.lower():
  1666. out.add_answer(msg, DNSAddress(
  1667. question.name, _TYPE_A,
  1668. _CLASS_IN | _CLASS_UNIQUE,
  1669. service.ttl, service.address))
  1670. name_to_find = question.name.lower()
  1671. if name_to_find not in self.services:
  1672. continue
  1673. service = self.services[name_to_find]
  1674. if question.type in (_TYPE_SRV, _TYPE_ANY):
  1675. out.add_answer(msg, DNSService(
  1676. question.name, _TYPE_SRV, _CLASS_IN | _CLASS_UNIQUE,
  1677. service.ttl, service.priority, service.weight,
  1678. service.port, service.server))
  1679. if question.type in (_TYPE_TXT, _TYPE_ANY):
  1680. out.add_answer(msg, DNSText(
  1681. question.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE,
  1682. service.ttl, service.text))
  1683. if question.type == _TYPE_SRV:
  1684. out.add_additional_answer(DNSAddress(
  1685. service.server, _TYPE_A, _CLASS_IN | _CLASS_UNIQUE,
  1686. service.ttl, service.address))
  1687. except Exception: # TODO stop catching all Exceptions
  1688. self.log_exception_warning()
  1689. if out is not None and out.answers:
  1690. out.id = msg.id
  1691. self.send(out, addr, port)
  1692. def send(self, out: DNSOutgoing, addr: str = _MDNS_ADDR, port: int = _MDNS_PORT) -> None:
  1693. """Sends an outgoing packet."""
  1694. packet = out.packet()
  1695. if len(packet) > _MAX_MSG_ABSOLUTE:
  1696. self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r",
  1697. out, len(packet), packet)
  1698. return
  1699. log.debug('Sending %r (%d bytes) as %r...', out, len(packet), packet)
  1700. for s in self._respond_sockets:
  1701. if self._GLOBAL_DONE:
  1702. return
  1703. try:
  1704. bytes_sent = s.sendto(packet, 0, (addr, port))
  1705. except Exception: # TODO stop catching all Exceptions
  1706. # on send errors, log the exception and keep going
  1707. self.log_exception_warning()
  1708. else:
  1709. if bytes_sent != len(packet):
  1710. self.log_warning_once(
  1711. '!!! sent %d out of %d bytes to %r' % (
  1712. bytes_sent, len(packet), s))
  1713. def close(self) -> None:
  1714. """Ends the background threads, and prevent this instance from
  1715. servicing further queries."""
  1716. if not self._GLOBAL_DONE:
  1717. self._GLOBAL_DONE = True
  1718. # remove service listeners
  1719. self.remove_all_service_listeners()
  1720. self.unregister_all_services()
  1721. # shutdown recv socket and thread
  1722. if not self.unicast:
  1723. self.engine.del_reader(self._listen_socket)
  1724. self._listen_socket.close()
  1725. else:
  1726. for s in self._respond_sockets:
  1727. self.engine.del_reader(s)
  1728. self.engine.join()
  1729. # shutdown the rest
  1730. self.notify_all()
  1731. self.reaper.join()
  1732. for s in self._respond_sockets:
  1733. s.close()