You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

67 lines
1.9 KiB

4 years ago
  1. import asyncio
  2. from .log import internal_logger
  3. class BaseProtocol(asyncio.Protocol):
  4. __slots__ = ('_loop', '_paused', '_drain_waiter',
  5. '_connection_lost', 'transport')
  6. def __init__(self, loop=None):
  7. if loop is None:
  8. self._loop = asyncio.get_event_loop()
  9. else:
  10. self._loop = loop
  11. self._paused = False
  12. self._drain_waiter = None
  13. self._connection_lost = False
  14. self.transport = None
  15. def pause_writing(self):
  16. assert not self._paused
  17. self._paused = True
  18. if self._loop.get_debug():
  19. internal_logger.debug("%r pauses writing", self)
  20. def resume_writing(self):
  21. assert self._paused
  22. self._paused = False
  23. if self._loop.get_debug():
  24. internal_logger.debug("%r resumes writing", self)
  25. waiter = self._drain_waiter
  26. if waiter is not None:
  27. self._drain_waiter = None
  28. if not waiter.done():
  29. waiter.set_result(None)
  30. def connection_made(self, transport):
  31. self.transport = transport
  32. def connection_lost(self, exc):
  33. self._connection_lost = True
  34. # Wake up the writer if currently paused.
  35. self.transport = None
  36. if not self._paused:
  37. return
  38. waiter = self._drain_waiter
  39. if waiter is None:
  40. return
  41. self._drain_waiter = None
  42. if waiter.done():
  43. return
  44. if exc is None:
  45. waiter.set_result(None)
  46. else:
  47. waiter.set_exception(exc)
  48. async def _drain_helper(self):
  49. if self._connection_lost:
  50. raise ConnectionResetError('Connection lost')
  51. if not self._paused:
  52. return
  53. waiter = self._drain_waiter
  54. assert waiter is None or waiter.cancelled()
  55. waiter = self._loop.create_future()
  56. self._drain_waiter = waiter
  57. await waiter