import re
import json
import socket
import asyncio
import logging
from time import time
import websockets
from .util import Queue, get as default_get, current_task
from .error import (
SocketIOError, ConnectionFailed,
ConnectionClosed, PingTimeout
)
from .proxy import ProxyError
class SocketIOResponse:
    """socket.io event response.

    A pending answer to an emitted event: ``match`` decides whether an
    incoming ``(event, data)`` pair answers it, and ``future`` is resolved
    with that pair.

    Attributes
    ----------
    id : `int`
        Identifier taken from a counter shared by all instances; wraps
        around to 0 at `MAX_ID`.
    match : `function`(`str`, `object`)
        Predicate called with an incoming event name and event data.
    future : `asyncio.Future`
        Resolved with the matching value, or failed/cancelled.
    """
    MAX_ID = 2 ** 32
    last_id = 0

    def __init__(self, match):
        # The counter must be stored on the class itself: assigning
        # ``self.last_id`` would only create an instance attribute and every
        # response would receive the same id.
        counter = SocketIOResponse
        counter.last_id = (counter.last_id + 1) % counter.MAX_ID
        self.id = counter.last_id
        self.match = match
        self.future = asyncio.Future()

    def __eq__(self, res):
        """Compare by identity with other responses, by id with anything else."""
        if isinstance(res, SocketIOResponse):
            return self is res
        return self.id == res

    def __hash__(self):
        # Defining __eq__ disables the inherited __hash__. Hashing the id keeps
        # ``response == response.id`` consistent with equal hashes.
        return hash(self.id)

    def __str__(self):
        return '<SocketIOResponse #%d>' % self.id

    __repr__ = __str__

    def set(self, value):
        """Resolve the response with ``value``.

        Does nothing if the future is already done (e.g. it was cancelled
        because the waiter timed out), instead of raising
        `asyncio.InvalidStateError` in the caller.
        """
        if not self.future.done():
            self.future.set_result(value)

    def cancel(self, ex=None):
        """Cancel the response, or fail it with ``ex`` if given.

        Does nothing if the future is already done.
        """
        if not self.future.done():
            if ex is None:
                self.future.cancel()
            else:
                self.future.set_exception(ex)

    @staticmethod
    def match_event(ev=None, data=None):
        """Build a match function usable as `SocketIOResponse.match`.

        Parameters
        ----------
        ev : `str` or `None`, optional
            Regular expression applied with `re.match` (anchored at the start)
            to the event name. `None` accepts any event name.
        data : `dict` or `None`, optional
            Every key of ``data`` must map to an equal value in the event
            data (a missing key compares as `None`). `None` accepts any data.

        Returns
        -------
        `function`(`str`, `object`) -> `bool`

        Raises
        ------
        `NotImplementedError`
            If ``data`` is neither `None` nor a `dict`. This is checked here
            rather than on every call, so the error surfaces to the caller
            instead of inside the receive loop.
        """
        if data is not None and not isinstance(data, dict):
            raise NotImplementedError('match_event !isinstance(data, dict)')

        def match(ev_, data_):
            if ev is not None:
                # Event names decoded from JSON are not guaranteed to be str;
                # re.match would raise TypeError on anything else.
                if not isinstance(ev_, str) or not re.match(ev, ev_):
                    return False
            if data is not None:
                if not isinstance(data_, dict):
                    return False
                for key, value in data.items():
                    if value != data_.get(key):
                        return False
            return True
        return match
class SocketIO:
    """Asynchronous socket.io connection.

    Attributes
    ----------
    websocket : `websockets.client.WebSocketClientProtocol`
        Websocket connection; `None` once `close` has finished.
    ping_interval : `float`
        Ping interval in seconds.
    ping_timeout : `float`
        Ping timeout in seconds.
    error : `None` or `Exception`
        First error recorded on the connection. Setting it to an exception
        schedules `close`.
    events : `asyncio.Queue` of ((`str`, `object`) or `None`)
        Event queue. `None` is queued when the connection is closed.
    response : `list` of `cytube_bot.socket_io.SocketIOResponse`
        Responses currently awaited by `emit`.
    response_lock : `asyncio.Lock`
        Held by `emit` while registering a response and sending its event.
    ping_task : `asyncio.tasks.Task`
        Task running `_ping`; `None` once `close` has finished.
    recv_task : `asyncio.tasks.Task`
        Task running `_recv`; `None` once `close` has finished.
    close_task : `asyncio.tasks.Task`
        Task running `close`, created when `error` is set.
    closing : `asyncio.Event`
        Set when `close` starts.
    closed : `asyncio.Event`
        Set when `close` has finished.
    ping_response : `asyncio.Event`
        Set when a pong ("3") packet is received.
    loop : `asyncio.events.AbstractEventLoop`
        Event loop.
    """
    logger = logging.getLogger(__name__)

    def __init__(self, websocket, config, qsize, loop):
        """
        Parameters
        ----------
        websocket : `websockets.client.WebSocketClientProtocol`
            Websocket connection.
        config : `dict`
            Websocket configuration (``pingInterval`` / ``pingTimeout`` are
            divided by 1000 to obtain seconds).
        qsize : `int`
            Event queue size.
        loop : `asyncio.events.AbstractEventLoop`
            Event loop.
        """
        self.websocket = websocket
        self.loop = loop
        self._error = None
        self.closing = asyncio.Event(loop=self.loop)
        self.closed = asyncio.Event(loop=self.loop)
        self.ping_response = asyncio.Event(loop=self.loop)
        self.events = Queue(maxsize=qsize, loop=self.loop)
        self.response = []
        self.response_lock = asyncio.Lock()
        # Milliseconds -> seconds, never below one second.
        self.ping_interval = max(1, config.get('pingInterval', 10000) / 1000)
        self.ping_timeout = max(1, config.get('pingTimeout', 10000) / 1000)
        # Initialise before starting the tasks: they may set `error`, which
        # assigns `close_task`.
        self.close_task = None
        self.ping_task = self.loop.create_task(self._ping())
        self.recv_task = self.loop.create_task(self._recv())

    @property
    def error(self):
        return self._error

    @error.setter
    def error(self, ex):
        # Only the first error is kept; setting one schedules `close`.
        if self._error is not None:
            self.logger.info('error already set: %r', self._error)
            return
        self.logger.info('set error %r', ex)
        self._error = ex
        if ex is not None:
            self.logger.info('create close task')
            self.close_task = self.loop.create_task(self.close())

    @asyncio.coroutine
    def close(self):
        """Close the connection.

        Records `ConnectionClosed` as the error unless one is already set,
        queues the `None` sentinel for `recv`, fails the futures of pending
        `emit` responses with the error, stops the ping and receive tasks
        and closes the websocket. Repeated or concurrent calls wait for the
        first one to finish.
        """
        self.logger.info('close')
        if self.close_task is not None:
            if self.close_task is current_task(self.loop):
                self.logger.info('current task is close task')
            else:
                self.logger.info('wait for close task')
                yield from asyncio.wait_for(self.close_task,
                                            None, loop=self.loop)
        if self.closed.is_set():
            self.logger.info('already closed')
            return
        if self.closing.is_set():
            self.logger.info('already closing, wait')
            yield from self.closed.wait()
            return
        self.closing.set()
        try:
            if self._error is None:
                self.logger.info('set error')
                self._error = ConnectionClosed()
            else:
                self.logger.info('error already set: %r', self._error)
            self.logger.info('queue null event')
            try:
                # Wakes a consumer blocked in `recv`, which raises the error
                # when it receives `None`.
                self.events.put_nowait(None)
            except asyncio.QueueFull:
                pass
            self.logger.info('set response future exception')
            for res in self.response:
                res.cancel(self.error)
            self.response = []
            self.logger.info('cancel ping task')
            self.ping_task.cancel()
            self.logger.info('cancel recv task')
            self.recv_task.cancel()
            self.logger.info('wait for tasks')
            # return_exceptions=True: a task that ended with an exception, or
            # was cancelled before it ever ran, must not abort the close
            # before the websocket below is closed.
            yield from asyncio.gather(self.ping_task, self.recv_task,
                                      return_exceptions=True)
            self.ping_response.clear()
            self.logger.info('close websocket')
            yield from self.websocket.close()
            self.logger.info('clear event queue')
            while not self.events.empty():
                ev = yield from self.events.get()
                self.events.task_done()
                if isinstance(ev, Exception):
                    self.error = ev
        finally:
            self.ping_task = None
            self.recv_task = None
            self.websocket = None
            self.closed.set()

    @asyncio.coroutine
    def recv(self):
        """Receive an event.

        Returns
        -------
        (`str`, `object`)
            Event name and data.

        Raises
        ------
        `ConnectionClosed`
            Or whichever error was recorded on the connection.
        """
        if self.error is not None:
            raise self.error  # pylint:disable=raising-bad-type
        ev = yield from self.events.get()
        self.events.task_done()
        if ev is None:
            # Sentinel queued by `close`.
            raise self.error  # pylint:disable=raising-bad-type
        return ev

    @asyncio.coroutine
    def emit(self, event, data, match_response=None, response_timeout=None):
        """Send an event.

        Parameters
        ----------
        event : `str`
            Event name.
        data : `object`
            Event data (serialized with `json.dumps`).
        match_response : `function` or `None`, optional
            Response match function, called with ``(event, data)`` of each
            received event (see `SocketIOResponse.match_event`). `None` (or
            `False`, the former default) sends without waiting for a response.
        response_timeout : `float` or `None`, optional
            Response timeout in seconds; `None` waits without a limit.

        Returns
        -------
        `object`
            The matching ``(event, data)`` pair when ``match_response`` is
            given; `None` when no response was requested or it timed out.

        Raises
        ------
        `asyncio.CancelledError`
        `SocketIOError`
        """
        if self.error is not None:
            raise self.error  # pylint:disable=raising-bad-type
        if match_response is False:
            match_response = None
        data = '42%s' % json.dumps((event, data))
        self.logger.info('emit %s', data)
        response = None
        try:
            if match_response is None:
                yield from self.websocket.send(data)
                return None
            # Register the response before sending, so a reply processed by
            # `_recv` while `send` is in progress still finds it.
            yield from self.response_lock.acquire()
            try:
                response = SocketIOResponse(match_response)
                self.logger.info('get response %s', response)
                self.response.append(response)
                yield from self.websocket.send(data)
            finally:
                self.response_lock.release()
            waiter = response.future
            if response_timeout is not None:
                waiter = asyncio.wait_for(waiter, response_timeout,
                                          loop=self.loop)
            try:
                res = yield from waiter
                self.logger.info('%s', res)
            except asyncio.CancelledError:
                self.logger.info('response cancelled %s', event)
                raise
            except asyncio.TimeoutError:
                self.logger.info('response timeout %s', event)
                response.cancel()
                res = None
            self.logger.info('response %s %r', event, res)
            return res
        except asyncio.CancelledError:
            self.logger.error('emit cancelled')
            raise
        except Exception as ex:
            self.logger.error('emit error: %r', ex)
            if isinstance(ex, SocketIOError):
                raise
            raise SocketIOError(ex) from ex
        finally:
            if response is not None:
                # Always unregister, including when `send` failed. A plain
                # list operation needs no lock and cannot be interrupted by
                # cancellation. `close` may already have replaced the list.
                try:
                    self.response.remove(response)
                except ValueError:
                    pass

    @asyncio.coroutine
    def _ping(self):
        """Ping task.

        Sends a ping ("2") packet every `ping_interval` seconds (minus the
        previous round-trip time) and waits up to `ping_timeout` seconds for
        `_recv` to set `ping_response`. On failure records `PingTimeout` or
        `ConnectionClosed` in `error`.
        """
        try:
            rtt = 0
            while self.error is None:
                yield from asyncio.sleep(max(self.ping_interval - rtt, 0))
                self.logger.debug('ping')
                self.ping_response.clear()
                sent_at = time()
                yield from self.websocket.send('2')
                yield from asyncio.wait_for(
                    self.ping_response.wait(),
                    self.ping_timeout,
                    loop=self.loop
                )
                rtt = max(time() - sent_at, 0)
        except asyncio.CancelledError:
            self.logger.info('ping cancelled')
        except asyncio.TimeoutError:
            self.logger.error('ping timeout')
            self.error = PingTimeout()
        except (socket.error,
                ProxyError,
                websockets.exceptions.ConnectionClosed,
                websockets.exceptions.InvalidState,
                websockets.exceptions.PayloadTooBig,
                websockets.exceptions.WebSocketProtocolError
               ) as ex:
            self.logger.error('ping error: %r', ex)
            self.error = ConnectionClosed(ex)

    @staticmethod
    def _parse_event(packet):
        """Split a message packet (starting with '4') into ``(event, data)``.

        '40' yields ``('', None)``, '41<x>' yields ``('<x>', None)``; any other
        packet is decoded as a JSON array ``[event, *args]``. A single argument
        is returned as-is, several arguments as a list.

        Raises
        ------
        `ValueError`
            If the JSON is invalid, not an array, or an empty array.
        `IndexError`
            If the packet is too short to carry a packet subtype.
        """
        kind = packet[1]
        if kind == '0':
            return '', None
        if kind == '1':
            return packet[2:], None
        data = json.loads(packet[2:])
        if not isinstance(data, list):
            raise ValueError('not an array')
        if not data:
            raise ValueError('empty array')
        if len(data) == 1:
            return data[0], None
        if len(data) == 2:
            return data[0], data[1]
        return data[0], data[1:]

    def _resolve_response(self, event, data):
        """Resolve the first pending `emit` response that accepts the event."""
        for response in self.response:
            # A response already done (e.g. its emit timed out but has not
            # unregistered it yet) is skipped: setting it again would fail and
            # it must not swallow the event meant for a live response.
            if response.future.done():
                continue
            if response.match(event, data):
                self.logger.debug('response %s %s', event, data)
                response.set((event, data))
                break

    @asyncio.coroutine
    def _recv(self):
        """Read task.

        Answers ping ("2") packets, signals pong ("3") packets through
        `ping_response`, queues message ("4") packets on `events` and
        resolves the matching pending `emit` response. Records
        `ConnectionClosed` in `error` when cancelled or when receiving fails.
        """
        try:
            while self.error is None:
                packet = yield from self.websocket.recv()
                self.logger.debug('recv %s', packet)
                if packet.startswith('2'):
                    payload = packet[1:]
                    self.logger.debug('ping %s', payload)
                    yield from self.websocket.send('3' + payload)
                elif packet.startswith('3'):
                    self.logger.debug('pong %s', packet[1:])
                    self.ping_response.set()
                elif packet.startswith('4'):
                    try:
                        event, data = self._parse_event(packet)
                    except (ValueError, IndexError) as ex:
                        # A malformed packet is logged and skipped rather than
                        # tearing down the connection.
                        self.logger.error('invalid event %s: %r', packet, ex)
                        continue
                    self.logger.debug('event %s %s', event, data)
                    yield from self.events.put((event, data))
                    self._resolve_response(event, data)
                else:
                    self.logger.warning('unknown event: "%s"', packet)
        except asyncio.CancelledError:
            self.logger.info('recv cancelled')
            self.error = ConnectionClosed()
        except (socket.error,
                ProxyError,
                websockets.exceptions.ConnectionClosed,
                websockets.exceptions.InvalidState,
                websockets.exceptions.PayloadTooBig,
                websockets.exceptions.WebSocketProtocolError
               ) as ex:
            self.logger.error('recv error: %r', ex)
            self.error = ConnectionClosed(ex)
        except Exception as ex:
            self.error = ConnectionClosed(ex)
            raise

    @classmethod
    @asyncio.coroutine
    def _get_config(cls, url, loop, get):
        """Get socket configuration.

        Parameters
        ----------
        url : `str`
        loop : `asyncio.events.AbstractEventLoop`
            Passed through to ``get``.
        get : `function`
            HTTP GET request coroutine.

        Returns
        -------
        `dict`
            Socket id, ping timeout, ping interval.

        Raises
        ------
        `websockets.exceptions.InvalidHandshake`
            If the response has no JSON object or that object lacks ``sid``.
        """
        # NOTE(review): engine.io's protocol-version query parameter is named
        # `EIO`; `EID` looks like a typo. Left unchanged because servers
        # currently accept these requests. Confirm before changing.
        url = url + '?EID=2&transport=polling'
        cls.logger.info('get %s', url)
        data = yield from get(url, loop=loop)
        try:
            # Skip any prefix before the JSON object.
            config = json.loads(data[data.index('{'):])
            if 'sid' not in config:
                raise ValueError('no sid in %s' % config)
        except ValueError as ex:
            raise websockets.exceptions.InvalidHandshake(data) from ex
        return config

    @classmethod
    @asyncio.coroutine
    def _connect(cls, url, qsize, loop, get, connect):
        """Create a connection.

        Parameters
        ----------
        url : `str`
        qsize : `int`
        loop : `asyncio.events.AbstractEventLoop`
        get : `function`
        connect : `function`

        Returns
        -------
        `SocketIO`
            An instance of ``cls``.
        """
        conf = yield from cls._get_config(url, loop, get)
        sid = conf['sid']
        cls.logger.info('sid=%s', sid)
        # http -> ws, https -> wss
        url = '%s?EID=3&transport=websocket&sid=%s' % (
            url.replace('http', 'ws', 1), sid
        )
        cls.logger.info('connect %s', url)
        websocket = yield from connect(url, loop=loop)
        try:
            cls.logger.info('2probe')
            yield from websocket.send('2probe')
            res = yield from websocket.recv()
            cls.logger.info('3probe')
            if res != '3probe':
                raise websockets.exceptions.InvalidHandshake(
                    'invalid response: "%s" != "3probe"' % (res,)
                )
            cls.logger.info('upgrade')
            yield from websocket.send('5')
            return cls(websocket, conf, qsize, loop)
        except BaseException:
            # Close the websocket whatever went wrong, including cancellation.
            yield from websocket.close()
            raise

    @classmethod
    @asyncio.coroutine
    def connect(cls,
                url,
                retry=-1,
                retry_delay=1,
                qsize=0,
                loop=None,
                get=default_get,
                connect=websockets.connect):
        """Create a connection.

        Parameters
        ----------
        url : `str`
            socket.io URL.
        retry : `int`
            Maximum number of retries after the first failed attempt
            (``retry + 1`` attempts in total); a negative value retries
            forever.
        retry_delay : `float`
            Delay between tries in seconds.
        qsize : `int`
            Event queue size.
        loop : `None` or `asyncio.events.AbstractEventLoop`
            Event loop.
        get : `function`
            HTTP GET request coroutine.
        connect : `function`
            Websocket connect coroutine.

        Returns
        -------
        `SocketIO`

        Raises
        ------
        `ConnectionFailed`
        `asyncio.CancelledError`
        """
        loop = loop or asyncio.get_event_loop()
        attempt = 0
        while True:
            try:
                io = yield from cls._connect(url, qsize, loop, get, connect)
                return io
            except asyncio.CancelledError:
                cls.logger.error(
                    'connect(%s) (try %d / %d): cancelled',
                    url, attempt + 1, retry + 1
                )
                raise
            except Exception as ex:
                cls.logger.error(
                    'connect(%s) (try %d / %d): %r',
                    url, attempt + 1, retry + 1, ex
                )
                if attempt == retry:
                    raise ConnectionFailed(ex) from ex
                attempt += 1
                yield from asyncio.sleep(retry_delay)