Source code for msl.network.network

"""
Base classes for a :class:`~msl.network.manager.Manager`,
:class:`~msl.network.service.Service` and :class:`~msl.network.client.Client`.
"""
import asyncio
import getpass
import socket
import sys
import threading
import traceback
from typing import Union

from .constants import HOSTNAME
from .constants import LOCALHOST_ALIASES
from .cryptography import get_ssl_context
from .json import deserialize
from .json import serialize
from .utils import _is_manager_regex
from .utils import logger


class Network(object):

    def __init__(self):
        """Base class for the :class:`~msl.network.manager.Manager`,
        :class:`~msl.network.service.Service` and :class:`~msl.network.client.Client`.
        """
        # the "network name" has the form name[host:port] once it is known
        self._network_name = '<UNKNOWN>'
        self._max_debug_length = 256
        self._identity = None
        self._loop = None
        self._reader = None
        self._writer = None

    def __str__(self):
        # '<UNKNOWN>' until a subclass assigns name[host:port]
        return self._network_name

    def identity(self) -> dict:
        """The identity of a device on the network.

        All devices on the network must be able to identify themselves to any
        other device that is connected to the network. There are 3 possible
        types of network devices -- a :class:`~msl.network.manager.Manager`,
        a :class:`~msl.network.service.Service` and a
        :class:`~msl.network.client.Client`. The member names and JSON_
        datatype for each network device are described below.

        .. _JSON: https://www.json.org/

        * :class:`~msl.network.manager.Manager`

            hostname: string
                The name of the computer that the Network
                :class:`~msl.network.manager.Manager` is running on.

            port: integer
                The port number that the Network
                :class:`~msl.network.manager.Manager` is running on.

            attributes: object
                An object (a Python :class:`dict`) of public attributes that
                the Network :class:`~msl.network.manager.Manager` provides.
                Users who are an administrator of the Network
                :class:`~msl.network.manager.Manager` can request private
                attributes, see :meth:`~msl.network.client.Client.admin_request`.

            language: string
                The programming language that the Network
                :class:`~msl.network.manager.Manager` is running on.

            os: string
                The name of the operating system that the Network
                :class:`~msl.network.manager.Manager` is running on.

            clients: object
                An object (a Python :class:`dict`) of all
                :class:`~msl.network.client.Client` devices that are currently
                connected to the Network :class:`~msl.network.manager.Manager`.

            services: object
                An object (a Python :class:`dict`) of all
                :class:`~msl.network.service.Service` devices that are
                currently connected to the Network
                :class:`~msl.network.manager.Manager`.

        * :class:`~msl.network.service.Service`

            type: string
                This must be equal to ``'service'`` (case-insensitive).

            name: string
                The name to associate with the
                :class:`~msl.network.service.Service` (can contain spaces).

            attributes: object
                An object (a Python :class:`dict`) of the attributes that the
                :class:`~msl.network.service.Service` provides. The keys are
                the method names and the values are the method signatures
                (expressed as a string).

                The `attributes` get populated automatically when subclassing
                :class:`~msl.network.service.Service`. If you are creating a
                `Service` in another programming language then you can use
                the following as an example for how to define an `attributes`
                object::

                    {
                        "pi": "() -> float",
                        "add_integers": "(x:int, y:int) -> int",
                        "scalar_multiply": "(a:float, data:List[floats]) -> List[floats]"
                    }

                This `Service` would provide a method named ``pi`` that takes
                no inputs and returns a floating-point number, a method named
                ``add_integers`` that takes parameters named ``x`` and ``y``
                as integer inputs and returns an integer, and a method named
                ``scalar_multiply`` that takes parameters named ``a`` as a
                floating-point number and ``data`` as an array of
                floating-point numbers as inputs and returns an array of
                floating-point numbers.

                The key **must** be equal to the name of the method that the
                `Service` provides; however, the value (the method signature)
                is only used as a helpful guide to let a
                :class:`~msl.network.client.Client` know what the method takes
                as inputs and what the method returns. How you express the
                method signature is up to you. The above example could also be
                expressed as::

                    {
                        "pi": "() -> 3.1415926...",
                        "add_integers": "(int32 x, int32 y) -> x+y",
                        "scalar_multiply": "(double a, *double data) -> *double"
                    }

            language: string, optional
                The programming language that the
                :class:`~msl.network.service.Service` is running on.

            os: string, optional
                The name of the operating system that the
                :class:`~msl.network.service.Service` is running on.

            max_clients: integer, optional
                The maximum number of :class:`~msl.network.client.Client`\\s
                that can be linked with the
                :class:`~msl.network.service.Service`. If the value is
                :math:`\\leq` 0 then that means that an unlimited number of
                :class:`~msl.network.client.Client`\\s can be linked *(this is
                the default setting if max_clients is not specified)*.

        * :class:`~msl.network.client.Client`

            type: string
                This must be equal to ``'client'`` (case-insensitive).

            name: string
                The name to associate with the
                :class:`~msl.network.client.Client` (can contain spaces).

            language: string, optional
                The programming language that the
                :class:`~msl.network.client.Client` is running on.

            os: string, optional
                The name of the operating system that the
                :class:`~msl.network.client.Client` is running on.

        Returns
        -------
        :class:`dict`
            The identity of the network device.
        """
        return self._identity

    @staticmethod
    def set_logging_level(level: Union[str, int]) -> bool:
        """Change the :ref:`logging level <levels>` of the ``msl.network`` logger.

        Parameters
        ----------
        level : :class:`int` or :class:`str`
            Either a numeric level (an :class:`int`, or a :class:`str` such
            as ``'20'`` that :class:`int` can parse) or a level name, which
            is matched case-insensitively.

        Returns
        -------
        :class:`bool`
            :data:`True` if the logger accepted the level, :data:`False`
            otherwise (an error message is also logged).
        """
        value = level
        if isinstance(value, str):
            # numeric strings become ints; anything else is treated as a name
            try:
                value = int(value)
            except ValueError:
                value = value.upper()

        try:
            logger.setLevel(value)
        except (ValueError, TypeError):
            logger.error('invalid logging level %r', value)
            return False
        return True

    async def _write(self, message, *, writer=None):
        """Serialize `message`, terminate it with CRLF and send it.

        Parameters
        ----------
        message : :class:`dict`
            A request or a response.
        writer : :class:`asyncio.StreamWriter`, optional
            Where to send the bytes. Defaults to the writer of this instance.
        """
        stream = self._writer if writer is None else writer
        stream.write(f'{serialize(message)}\r\n'.encode('utf-8'))
        await stream.drain()

    async def _write_result(self, result, *, requester=None, uid='', writer=None, **ignored):  # noqa
        """Send a successful reply.

        Parameters
        ----------
        result
            The value to send back. It must be JSON-serializable or provide a
            ``to_json()`` method.
        requester : :class:`str`, optional
            The name of the device that made the request.
        uid : :class:`str`, optional
            The unique identifier of the request.
        writer : :class:`asyncio.StreamWriter`, optional
            Where to send the reply. Defaults to the writer of this instance.
        **ignored
            Any other keys (e.g., the rest of a request dict) are discarded.
        """
        reply = dict(error=False, requester=requester, result=result, uid=uid)
        await self._write(reply, writer=writer)

    async def _write_error(self, error, *, requester=None, uid='', writer=None, **ignored):  # noqa
        """Send an error reply.

        Parameters
        ----------
        error : :class:`Exception`
            The exception that occurred.
        requester : :class:`str`, optional
            The name of the device that made the request.
        uid : :class:`str`, optional
            The unique identifier of the request.
        writer : :class:`asyncio.StreamWriter`, optional
            Where to send the reply. Defaults to the writer of this instance.
        **ignored
            Any other keys are discarded.
        """
        formatted = traceback.format_exc()
        # format_exc() yields 'NoneType: None' when no exception is being
        # handled, in which case there is no traceback worth sending
        no_active_exception = formatted.startswith('NoneType:')
        reply = {
            'error': True,
            'message': f'{error.__class__.__name__}: {error}',
            'requester': requester,
            'result': None,
            'traceback': [] if no_active_exception else formatted.splitlines(),
            'uid': uid,
        }
        await self._write(reply, writer=writer)
class Device(Network):

    def __init__(self, name=None):
        """Base class for a :class:`~msl.network.service.Service` and
        :class:`~msl.network.client.Client`.

        .. versionadded:: 1.0

        Parameters
        ----------
        name : :class:`str`, optional
            The name of the device as it will appear on the Network
            :class:`~msl.network.manager.Manager`. If not specified then the
            class name is used.
        """
        super(Device, self).__init__()
        self._address_manager = None
        self._name = self.__class__.__name__ if name is None else name
        self._password = None
        self._password_manager = None
        self._port = None
        self._queue = None
        self._tasks = []
        self._username = None
        self._loop_thread_id = None

    @property
    def address_manager(self):
        """:class:`str`: The address of the :class:`~msl.network.manager.Manager`
        that this device is connected to."""
        return self._address_manager

    @property
    def loop_thread_id(self):
        """Identifier of the thread running the event loop.

        Returns :data:`None` if the event loop is not running.

        .. versionadded:: 1.0
        """
        return self._loop_thread_id

    @property
    def name(self):
        """:class:`str`: The name of the device on the
        :class:`~msl.network.manager.Manager`."""
        return self._name

    @property
    def port(self):
        """:class:`int`: The port number of this device that is being used
        for the connection to the :class:`~msl.network.manager.Manager`."""
        return self._port

    def add_tasks(self, *coros_or_futures):
        """Additional tasks to run in the event loop.

        .. versionadded:: 1.0

        Parameters
        ----------
        coros_or_futures
            Coroutines or futures that will be passed to
            :func:`asyncio.gather` when the event loop runs.
        """
        self._tasks.extend(coros_or_futures)

    def shutdown_handler(self):
        """Called after the connection to the Network
        :class:`~msl.network.manager.Manager` has been lost but before the
        event loop stops.

        Override this method to do any necessary cleanup.

        .. versionadded:: 1.0
        """

    def _create_connection(self, **kwargs):
        """Connect to, and authenticate with, a Network
        :class:`~msl.network.manager.Manager`.

        A new event loop is created and set as the current event loop. If
        connecting or authenticating fails, the socket is aborted, the current
        event loop is reset to :data:`None` and the new loop is closed before
        the exception propagates, so nothing is leaked.

        Parameters
        ----------
        **kwargs
            The connection settings. The keys that are used are ``host``,
            ``port``, ``username``, ``password``, ``password_manager``,
            ``disable_tls``, ``cert_file``, ``auto_save``,
            ``assert_hostname``, ``timeout`` and ``read_limit``.

        Returns
        -------
        :class:`asyncio.AbstractEventLoop` or :data:`None`
            The event loop that the connection was opened in, or
            :data:`None` if the user chose to not accept the TLS certificate.

        Raises
        ------
        ConnectionError
            If a connection to the Manager cannot be established.
        TimeoutError
            If the connection is not established within ``timeout`` seconds.
        ValueError
            If the Manager replies with an error during authentication.
        """
        self._address_manager = '{host}:{port}'.format(**kwargs)
        self._username = kwargs['username']
        self._password = kwargs['password']
        self._password_manager = kwargs['password_manager']

        # get SSL context
        context = None
        cert_file = kwargs['cert_file']
        if not kwargs['disable_tls']:
            # In Python 3.10, ssl.get_server_certificate() accepts a timeout parameter
            kws = {'timeout': kwargs['timeout']} if sys.version_info[:2] >= (3, 10) else {}
            try:
                cert_file, context = get_ssl_context(
                    cert_file=kwargs['cert_file'],
                    host=kwargs['host'],
                    port=kwargs['port'],
                    auto_save=kwargs['auto_save'],
                    **kws
                )
            except OSError as error:
                e = str(error)
                if ('WRONG_VERSION_NUMBER' in e) or ('UNKNOWN_PROTOCOL' in e):
                    e += '\nTry setting disable_tls=True'
                elif kwargs['host'] in LOCALHOST_ALIASES:
                    e += '\nMake sure a Manager is running on this computer'
                else:
                    e += '\nCannot connect to {host}:{port} to get the ' \
                         'certificate'.format(**kwargs)
                raise ConnectionError(e) from None

            if context is None:
                # then the user chose to not accept the SSL certificate
                return

            context.check_hostname = kwargs['assert_hostname']
            logger.debug('loaded %s', cert_file)

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        if kwargs['read_limit'] is None:
            kwargs['read_limit'] = sys.maxsize

        writer = None
        try:
            # connect
            try:
                reader, writer = loop.run_until_complete(
                    asyncio.wait_for(
                        asyncio.open_connection(
                            host=kwargs['host'],
                            port=kwargs['port'],
                            ssl=context,
                            limit=kwargs['read_limit'],
                        ),
                        kwargs['timeout']
                    )
                )
            except Exception as error:
                if isinstance(error, asyncio.TimeoutError):
                    raise TimeoutError(
                        'Cannot connect to {host}:{port} within '
                        '{timeout} seconds'.format(**kwargs)
                    ) from None

                msg = str(error)
                if msg.startswith('Multiple exceptions'):  # comes from asyncio
                    msg = 'Cannot connect to {host}:{port}'.format(**kwargs)
                elif isinstance(error, (ConnectionRefusedError, socket.gaierror)):
                    msg += '\nCannot connect to {host}:{port}'.format(**kwargs)
                elif ('name mismatch' in msg) or ('address mismatch' in msg):
                    msg += '\nTo disable hostname checking set assert_hostname=False\n' \
                           'Make sure you trust the connection to {host}:{port} ' \
                           'if you decide to do this.'.format(**kwargs)
                elif 'CERTIFICATE_VERIFY_FAILED' in msg:
                    # report the certificate that was actually loaded, not the
                    # (possibly None) value that the caller passed in
                    msg += '\nPerhaps the Network Manager is using a new certificate.\n' \
                           'If you trust the connection to {host}:{port}, you can delete ' \
                           'the certificate at\n  {cert_file}\nand then re-connect to ' \
                           'create a new trusted certificate.'.format(
                               **{**kwargs, 'cert_file': cert_file})
                elif ('WRONG_VERSION_NUMBER' in msg) or ('UNKNOWN_PROTOCOL' in msg):
                    msg += '\nTry setting disable_tls=True'
                elif 'nodename nor servname provided' in msg:
                    host = kwargs['host']
                    msg += f'\nYou might need to add "{host} {HOSTNAME}" to /etc/hosts'
                raise ConnectionError(msg) from None

            self._reader, self._writer = reader, writer

            # authenticate
            try:
                line = loop.run_until_complete(
                    asyncio.wait_for(self._reader.readline(), kwargs['timeout']))
            except asyncio.TimeoutError:
                msg = 'The connection to {host}:{port} was not established after ' \
                      '{timeout} second(s)'.format(**kwargs)
                if kwargs['disable_tls']:
                    msg += '\nYou have TLS disabled. Perhaps the Manager is ' \
                           'using TLS for the connection.'
                raise ConnectionError(msg) from None

            loop.run_until_complete(self._authenticate(line))
        except BaseException:
            # Do not leak the socket or the event loop if connecting or
            # authenticating failed (same clean-up order as asyncio.run)
            try:
                if writer is not None:
                    writer.transport.abort()
                    # one iteration lets the transport close its socket
                    loop.run_until_complete(asyncio.sleep(0))
            finally:
                asyncio.set_event_loop(None)
                loop.close()
            raise

        return loop

    def _run_until_complete(self):
        """Run all tasks in the event loop until they are complete, then
        close the connection and the event loop."""
        # must instantiate the Queue after the connection has been established
        # since self._create_connection creates a new event loop
        self._queue = asyncio.Queue()
        self._loop_thread_id = threading.get_ident()
        try:
            self._loop.run_until_complete(self._gather())
        except KeyboardInterrupt:
            logger.debug('CTRL+C keyboard interrupt')
        except SystemExit:
            logger.debug('SystemExit raised')
        except Exception as e:
            logger.exception(e)
        finally:
            try:
                self._reader.feed_eof()
                self._loop.run_until_complete(self._queue.join())
                self._writer.close()
                self._loop.run_until_complete(self._writer.wait_closed())
            finally:
                # always release the loop, even if closing the stream raised
                self._loop.close()
                self._loop_thread_id = None

        try:
            logger.info('disconnected from Manager[%s]', self._address_manager)
        except (NameError, ValueError):
            # These errors could occur when Python is exiting
            #   ValueError: I/O operation on closed file
            #   NameError: name 'open' is not defined
            pass

    async def _authenticate(self, line):
        """Answer the requests that the Manager sends while connecting.

        Parameters
        ----------
        line : :class:`bytes`
            The first line that was read from the Manager.

        Raises
        ------
        ValueError
            If a message from the Manager has its ``error`` flag set.
        """
        # The Manager may ask for a username/password and will always request
        # the identity of the connecting device
        logger.debug('start authentication')
        while True:
            request = deserialize(line)
            if request['error']:
                raise ValueError(request['message'])
            identified = await self._handle_manager_request(request)
            if identified:
                break
            line = await self._reader.readline()
        logger.debug('finish authentication')

    async def _gather(self):
        """Run all tasks that were added via :meth:`add_tasks` concurrently."""
        await asyncio.gather(*self._tasks)

    async def _handle_manager_request(self, request):
        """Reply to a single authentication request from the Manager.

        Parameters
        ----------
        request : :class:`dict`
            The deserialized request. Its ``attribute`` is one of
            ``'identity'``, ``'username'`` or ``'password'``.

        Returns
        -------
        :class:`bool` or :data:`None`
            :data:`True` after the identity has been sent, otherwise
            :data:`None` (more requests are expected).

        Raises
        ------
        ValueError
            If the Manager requested an unsupported attribute.
        """
        attribute = request['attribute']
        logger.debug('Manager[%s] requested %r', self._address_manager, attribute)

        if attribute == 'identity':
            await self._write_result(self._identity)
            self._port = int(self._writer.get_extra_info('sockname')[1])
            self._network_name = f'{self._name}[{HOSTNAME}:{self._port}]'
            logger.info('connected to Manager[%s] as %s',
                        self._address_manager, self._network_name)
            return True

        if attribute == 'username':
            if self._username is None:
                name = request['args'][0]
                self._username = input(f'Enter a username for {name} > ')
            await self._write_result(self._username, **request)
        elif attribute == 'password':
            name = request['args'][0]
            prompt = f'Enter the password for {name} > '
            # NOTE(review): a name matching _is_manager_regex presumably means
            # the Manager's own password is requested, otherwise a user login
            if _is_manager_regex.search(name) is not None:
                if self._password_manager is None:
                    self._password_manager = getpass.getpass(prompt)
                password = self._password_manager
            else:
                if self._password is None:
                    self._password = getpass.getpass(prompt)
                password = self._password
            await self._write_result(password, **request)
        else:
            # explicit raise instead of `assert`, which is stripped under -O
            raise ValueError(f'Unexpected request from the Manager: {attribute!r}')