You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

543 lines
20 KiB

4 years ago
  1. import re
  2. import datetime
  3. from decimal import Decimal
  4. from base64 import b64decode
  5. from binascii import unhexlify
  6. from functools import partial
  7. from collections import OrderedDict
  8. import six
  9. import requests
  10. from requests.compat import urljoin, urlparse
  11. from dateutil.parser import parse as parse_date
  12. from lxml import etree
  13. from .util import _getLogger
  14. from .const import HTTP_TIMEOUT
  15. from .soap import SOAP
  16. from .marshal import marshal_value
  17. class UPNPError(Exception):
  18. """
  19. Exception class for UPnP errors.
  20. """
  21. pass
  22. class InvalidActionException(UPNPError):
  23. """
  24. Action doesn't exist.
  25. """
  26. pass
  27. class ValidationError(UPNPError):
  28. """
  29. Given value didn't validate with the given data type.
  30. """
  31. def __init__(self, reasons):
  32. super(ValidationError, self).__init__()
  33. self.reasons = reasons
  34. class UnexpectedResponse(UPNPError):
  35. """
  36. Got a response we didn't expect.
  37. """
  38. pass
  39. class CallActionMixin(object):
  40. def __call__(self, action_name, **kwargs):
  41. """
  42. Convenience method for quickly finding and calling an Action on a
  43. Service. Must have implemented a `find_action(action_name)` method.
  44. """
  45. action = self.find_action(action_name)
  46. if action is not None:
  47. return action(**kwargs)
  48. raise InvalidActionException('Action with name %r does not exist.' % action_name)
  49. class Device(CallActionMixin):
  50. """
  51. UPNP Device represention.
  52. This class represents an UPnP device. `location` is an URL to a control XML
  53. file, per UPnP standard section 2.3 ('Device Description'). This MUST match
  54. the URL as given in the 'Location' header when using discovery (SSDP).
  55. `device_name` is a name for the device, which may be obtained using the
  56. SSDP class or may be made up by the caller.
  57. Raises urllib2.HTTPError when the location is invalid
  58. Example:
  59. >>> device = Device('http://192.168.1.254:80/upnp/IGD.xml')
  60. >>> for service in device.services:
  61. ... print service.service_id
  62. ...
  63. urn:upnp-org:serviceId:layer3f
  64. urn:upnp-org:serviceId:wancic
  65. urn:upnp-org:serviceId:wandsllc:pvc_Internet
  66. urn:upnp-org:serviceId:wanipc:Internet
  67. """
  68. def __init__(
  69. self, location, device_name=None, ignore_urlbase=False,
  70. http_auth=None, http_headers=None):
  71. """
  72. Create a new Device instance. `location` is an URL to an XML file
  73. describing the server's services.
  74. """
  75. self.location = location
  76. self.device_name = location if device_name is None else device_name
  77. self.services = []
  78. self.service_map = {}
  79. self._log = _getLogger('Device')
  80. self.http_auth = http_auth
  81. self.http_headers = http_headers
  82. resp = requests.get(
  83. location,
  84. timeout=HTTP_TIMEOUT,
  85. auth=self.http_auth,
  86. headers=self.http_headers
  87. )
  88. resp.raise_for_status()
  89. root = etree.fromstring(resp.content)
  90. findtext = partial(root.findtext, namespaces=root.nsmap)
  91. self.device_type = findtext('device/deviceType')
  92. self.friendly_name = findtext('device/friendlyName')
  93. self.manufacturer = findtext('device/manufacturer')
  94. self.manufacturer_url = findtext('device/manufacturerURL')
  95. self.model_description = findtext('device/modelDescription')
  96. self.model_name = findtext('device/modelName')
  97. self.model_number = findtext('device/modelNumber')
  98. self.serial_number = findtext('device/serialNumber')
  99. self.udn = findtext('device/UDN')
  100. self._url_base = findtext('URLBase')
  101. if self._url_base is None or ignore_urlbase:
  102. # If no URL Base is given, the UPnP specification says: "the base
  103. # URL is the URL from which the device description was retrieved"
  104. self._url_base = self.location
  105. self._root_xml = root
  106. self._findtext = findtext
  107. self._find = partial(root.find, namespaces=root.nsmap)
  108. self._findall = partial(root.findall, namespaces=root.nsmap)
  109. self._read_services()
  110. def __repr__(self):
  111. return "<Device '%s'>" % (self.friendly_name)
  112. def __getattr__(self, name):
  113. """
  114. Allow Services to be returned as members of the Device.
  115. """
  116. try:
  117. return self.service_map[name]
  118. except KeyError:
  119. raise AttributeError('No attribute or service found with name %r.' % name)
  120. def __getitem__(self, key):
  121. """
  122. Allow Services to be returned as dictionary keys of the Device.
  123. """
  124. return self.service_map[key]
  125. def __dir__(self):
  126. """
  127. Add Service names to `dir(device)` output for use with tab-completion in repl.
  128. """
  129. return super(Device, self).__dir__() + list(self.service_map.keys())
  130. @property
  131. def actions(self):
  132. actions = []
  133. for service in self.services:
  134. actions.extend(service.actions)
  135. return actions
  136. def _read_services(self):
  137. """
  138. Read the control XML file and populate self.services with a list of
  139. services in the form of Service class instances.
  140. """
  141. # The double slash in the XPath is deliberate, as services can be
  142. # listed in two places (Section 2.3 of uPNP device architecture v1.1)
  143. for node in self._findall('device//serviceList/service'):
  144. findtext = partial(node.findtext, namespaces=self._root_xml.nsmap)
  145. svc = Service(
  146. self,
  147. self._url_base,
  148. findtext('serviceType'),
  149. findtext('serviceId'),
  150. findtext('controlURL'),
  151. findtext('SCPDURL'),
  152. findtext('eventSubURL')
  153. )
  154. self._log.debug(
  155. '%s: Service %r at %r', self.device_name, svc.service_type, svc.scpd_url)
  156. self.services.append(svc)
  157. self.service_map[svc.name] = svc
  158. def find_action(self, action_name):
  159. """Find an action by name.
  160. Convenience method that searches through all the services offered by
  161. the Server for an action and returns an Action instance. If the action
  162. is not found, returns None. If multiple actions with the same name are
  163. found it returns the first one.
  164. """
  165. for service in self.services:
  166. action = service.find_action(action_name)
  167. if action is not None:
  168. return action
  169. class Service(CallActionMixin):
  170. """
  171. Service Control Point Definition. This class reads an SCPD XML file and
  172. parses the actions and state variables. It can then be used to call
  173. actions.
  174. """
  175. def __init__(self, device, url_base, service_type, service_id,
  176. control_url, scpd_url, event_sub_url):
  177. self.device = device
  178. self._url_base = url_base
  179. self.service_type = service_type
  180. self.service_id = service_id
  181. self._control_url = control_url
  182. self.scpd_url = scpd_url
  183. self._event_sub_url = event_sub_url
  184. self.actions = []
  185. self.action_map = {}
  186. self.statevars = {}
  187. self._log = _getLogger('Service')
  188. self._log.debug('%s url_base: %s', self.service_id, self._url_base)
  189. self._log.debug('%s SCPDURL: %s', self.service_id, self.scpd_url)
  190. self._log.debug('%s controlURL: %s', self.service_id, self._control_url)
  191. self._log.debug('%s eventSubURL: %s', self.service_id, self._event_sub_url)
  192. url = urljoin(self._url_base, self.scpd_url)
  193. self._log.debug('Reading %s', url)
  194. resp = requests.get(
  195. url,
  196. timeout=HTTP_TIMEOUT,
  197. auth=self.device.http_auth,
  198. headers=self.device.http_headers
  199. )
  200. resp.raise_for_status()
  201. self.scpd_xml = etree.fromstring(resp.content)
  202. self._find = partial(self.scpd_xml.find, namespaces=self.scpd_xml.nsmap)
  203. self._findtext = partial(self.scpd_xml.findtext, namespaces=self.scpd_xml.nsmap)
  204. self._findall = partial(self.scpd_xml.findall, namespaces=self.scpd_xml.nsmap)
  205. self._read_state_vars()
  206. self._read_actions()
  207. def __repr__(self):
  208. return "<Service service_id='%s'>" % (self.service_id)
  209. def __getattr__(self, name):
  210. """
  211. Allow Actions to be returned as members of the Service.
  212. """
  213. try:
  214. return self.action_map[name]
  215. except KeyError:
  216. raise AttributeError('No attribute or action found with name %r.' % name)
  217. def __getitem__(self, key):
  218. """
  219. Allow Actions to be returned as dictionary keys of the Service.
  220. """
  221. return self.action_map[key]
  222. def __dir__(self):
  223. """
  224. Add Action names to `dir(service)` output for use with tab-completion in repl.
  225. """
  226. return super(Service, self).__dir__() + [a.name for a in self.actions]
  227. @property
  228. def name(self):
  229. try:
  230. return self.service_id[self.service_id.rindex(":")+1:]
  231. except ValueError:
  232. return self.service_id
  233. def _read_state_vars(self):
  234. for statevar_node in self._findall('serviceStateTable/stateVariable'):
  235. findtext = partial(statevar_node.findtext, namespaces=statevar_node.nsmap)
  236. findall = partial(statevar_node.findall, namespaces=statevar_node.nsmap)
  237. name = findtext('name')
  238. datatype = findtext('dataType')
  239. send_events = statevar_node.attrib.get('sendEvents', 'yes').lower() == 'yes'
  240. allowed_values = set([e.text for e in findall('allowedValueList/allowedValue')])
  241. self.statevars[name] = dict(
  242. name=name,
  243. datatype=datatype,
  244. allowed_values=allowed_values,
  245. send_events=send_events
  246. )
  247. def _read_actions(self):
  248. action_url = urljoin(self._url_base, self._control_url)
  249. for action_node in self._findall('actionList/action'):
  250. name = action_node.findtext('name', namespaces=action_node.nsmap)
  251. argsdef_in = []
  252. argsdef_out = []
  253. for arg_node in action_node.findall(
  254. 'argumentList/argument', namespaces=action_node.nsmap):
  255. findtext = partial(arg_node.findtext, namespaces=arg_node.nsmap)
  256. arg_name = findtext('name')
  257. arg_statevar = self.statevars[findtext('relatedStateVariable')]
  258. if findtext('direction').lower() == 'in':
  259. argsdef_in.append((arg_name, arg_statevar))
  260. else:
  261. argsdef_out.append((arg_name, arg_statevar))
  262. action = Action(self, action_url, self.service_type, name, argsdef_in, argsdef_out)
  263. self.action_map[name] = action
  264. self.actions.append(action)
  265. @staticmethod
  266. def validate_subscription_response(resp):
  267. lc_headers = {k.lower(): v for k, v in resp.headers.items()}
  268. try:
  269. sid = lc_headers['sid']
  270. except KeyError:
  271. raise UnexpectedResponse('Event subscription call returned without a "SID" header')
  272. try:
  273. timeout_str = lc_headers['timeout'].lower()
  274. except KeyError:
  275. raise UnexpectedResponse('Event subscription call returned without a "Timeout" header')
  276. if not timeout_str.startswith('second-'):
  277. raise UnexpectedResponse(
  278. 'Event subscription call returned an invalid timeout value: %r' % timeout_str)
  279. timeout_str = timeout_str[len('Second-'):]
  280. try:
  281. timeout = None if timeout_str == 'infinite' else int(timeout_str)
  282. except ValueError:
  283. raise UnexpectedResponse(
  284. 'Event subscription call returned a timeout value which wasn\'t "infinite" or an in'
  285. 'teger')
  286. return sid, timeout
  287. @staticmethod
  288. def validate_subscription_renewal_response(resp):
  289. lc_headers = {k.lower(): v for k, v in resp.headers.items()}
  290. try:
  291. timeout_str = lc_headers['timeout'].lower()
  292. except KeyError:
  293. raise UnexpectedResponse('Event subscription call returned without a "Timeout" header')
  294. if not timeout_str.startswith('second-'):
  295. raise UnexpectedResponse(
  296. 'Event subscription call returned an invalid timeout value: %r' % timeout_str)
  297. timeout_str = timeout_str[len('Second-'):]
  298. try:
  299. timeout = None if timeout_str == 'infinite' else int(timeout_str)
  300. except ValueError:
  301. raise UnexpectedResponse(
  302. 'Event subscription call returned a timeout value which wasn\'t "infinite" or an in'
  303. 'teger')
  304. return timeout
  305. def find_action(self, action_name):
  306. try:
  307. return self.action_map[action_name]
  308. except KeyError:
  309. pass
  310. def subscribe(self, callback_url, timeout=None):
  311. """
  312. Set up a subscription to the events offered by this service.
  313. """
  314. url = urljoin(self._url_base, self._event_sub_url)
  315. headers = dict(
  316. HOST=urlparse(url).netloc,
  317. CALLBACK='<%s>' % callback_url,
  318. NT='upnp:event'
  319. )
  320. if timeout is not None:
  321. headers['TIMEOUT'] = 'Second-%s' % timeout
  322. resp = requests.request('SUBSCRIBE', url, headers=headers, auth=self.device.http_auth)
  323. resp.raise_for_status()
  324. return Service.validate_subscription_response(resp)
  325. def renew_subscription(self, sid, timeout=None):
  326. """
  327. Renews a previously configured subscription.
  328. """
  329. url = urljoin(self._url_base, self._event_sub_url)
  330. headers = dict(
  331. HOST=urlparse(url).netloc,
  332. SID=sid
  333. )
  334. if timeout is not None:
  335. headers['TIMEOUT'] = 'Second-%s' % timeout
  336. resp = requests.request('SUBSCRIBE', url, headers=headers, auth=self.device.http_auth)
  337. resp.raise_for_status()
  338. return Service.validate_subscription_renewal_response(resp)
  339. def cancel_subscription(self, sid):
  340. """
  341. Unsubscribes from a previously configured subscription.
  342. """
  343. url = urljoin(self._url_base, self._event_sub_url)
  344. headers = dict(
  345. HOST=urlparse(url).netloc,
  346. SID=sid
  347. )
  348. resp = requests.request('UNSUBSCRIBE', url, headers=headers, auth=self.device.http_auth)
  349. resp.raise_for_status()
  350. class Action(object):
  351. def __init__(self, service, url, service_type, name, argsdef_in=None, argsdef_out=None):
  352. if argsdef_in is None:
  353. argsdef_in = []
  354. if argsdef_out is None:
  355. argsdef_out = []
  356. self.service = service
  357. self.url = url
  358. self.service_type = service_type
  359. self.name = name
  360. self.argsdef_in = argsdef_in
  361. self.argsdef_out = argsdef_out
  362. self._log = _getLogger('Action')
  363. def __repr__(self):
  364. return "<Action '%s'>" % (self.name)
  365. def __call__(self, http_auth=None, http_headers=None, **kwargs):
  366. arg_reasons = {}
  367. call_kwargs = OrderedDict()
  368. # Validate arguments using the SCPD stateVariable definitions
  369. for name, statevar in self.argsdef_in:
  370. if name not in kwargs:
  371. raise UPNPError('Missing required param \'%s\'' % (name))
  372. valid, reasons = self.validate_arg(kwargs[name], statevar)
  373. if not valid:
  374. arg_reasons[name] = reasons
  375. # Preserve the order of call args, as listed in SCPD XML spec
  376. call_kwargs[name] = kwargs[name]
  377. if arg_reasons:
  378. raise ValidationError(arg_reasons)
  379. # Make the actual call
  380. self._log.debug(">> %s (%s)", self.name, call_kwargs)
  381. soap_client = SOAP(self.url, self.service_type)
  382. soap_response = soap_client.call(
  383. self.name,
  384. call_kwargs,
  385. http_auth or self.service.device.http_auth,
  386. http_headers or self.service.device.http_headers
  387. )
  388. self._log.debug("<< %s (%s): %s", self.name, call_kwargs, soap_response)
  389. # Marshall the response to python data types
  390. out = {}
  391. for name, statevar in self.argsdef_out:
  392. _, value = marshal_value(statevar['datatype'], soap_response[name])
  393. out[name] = value
  394. return out
  395. @staticmethod
  396. def validate_arg(arg, argdef):
  397. """
  398. Validate an incoming (unicode) string argument according the UPnP spec. Raises UPNPError.
  399. """
  400. datatype = argdef['datatype']
  401. reasons = set()
  402. ranges = {
  403. 'ui1': (int, 0, 255),
  404. 'ui2': (int, 0, 65535),
  405. 'ui4': (int, 0, 4294967295),
  406. 'i1': (int, -128, 127),
  407. 'i2': (int, -32768, 32767),
  408. 'i4': (int, -2147483648, 2147483647),
  409. 'r4': (Decimal, Decimal('3.40282347E+38'), Decimal('1.17549435E-38'))
  410. }
  411. try:
  412. if datatype in set(ranges.keys()):
  413. v_type, v_min, v_max = ranges[datatype]
  414. if not v_min <= v_type(arg) <= v_max:
  415. reasons.add('%r datatype must be a number in the range %s to %s' % (
  416. datatype, v_min, v_max))
  417. elif datatype in {'r8', 'number', 'float', 'fixed.14.4'}:
  418. v = Decimal(arg)
  419. if v < 0:
  420. assert Decimal('-1.79769313486232E308') <= v <= Decimal('4.94065645841247E-324')
  421. else:
  422. assert Decimal('4.94065645841247E-324') <= v <= Decimal('1.79769313486232E308')
  423. elif datatype == 'char':
  424. v = arg.decode('utf8') if six.PY2 or isinstance(arg, bytes) else arg
  425. assert len(v) == 1
  426. elif datatype == 'string':
  427. v = arg.decode("utf8") if six.PY2 or isinstance(arg, bytes) else arg
  428. if argdef['allowed_values'] and v not in argdef['allowed_values']:
  429. reasons.add('Value %r not in allowed values list' % arg)
  430. elif datatype == 'date':
  431. v = parse_date(arg)
  432. if any((v.hour, v.minute, v.second)):
  433. reasons.add("'date' datatype must not contain a time")
  434. elif datatype in ('dateTime', 'dateTime.tz'):
  435. v = parse_date(arg)
  436. if datatype == 'dateTime' and v.tzinfo is not None:
  437. reasons.add("'dateTime' datatype must not contain a timezone")
  438. elif datatype in ('time', 'time.tz'):
  439. now = datetime.datetime.utcnow()
  440. v = parse_date(arg, default=now)
  441. if v.tzinfo is not None:
  442. now += v.utcoffset()
  443. if not all((
  444. v.day == now.day,
  445. v.month == now.month,
  446. v.year == now.year)):
  447. reasons.add('%r datatype must not contain a date' % datatype)
  448. if datatype == 'time' and v.tzinfo is not None:
  449. reasons.add('%r datatype must not have timezone information' % datatype)
  450. elif datatype == 'boolean':
  451. valid = {'true', 'yes', '1', 'false', 'no', '0'}
  452. if arg.lower() not in valid:
  453. reasons.add('%r datatype must be one of %s' % (datatype, ','.join(valid)))
  454. elif datatype == 'bin.base64':
  455. b64decode(arg)
  456. elif datatype == 'bin.hex':
  457. unhexlify(arg)
  458. elif datatype == 'uri':
  459. urlparse(arg)
  460. elif datatype == 'uuid':
  461. if not re.match(
  462. r'^[0-9a-f]{8}\-[0-9a-f]{4}\-[0-9a-f]{4}\-[0-9a-f]{4}\-[0-9a-f]{12}$',
  463. arg, re.I):
  464. reasons.add('%r datatype must contain a valid UUID')
  465. else:
  466. reasons.add("%r datatype is unrecognised." % datatype)
  467. except ValueError as exc:
  468. reasons.add(str(exc))
  469. return not bool(len(reasons)), reasons