From c4cd26737adc648cb948ab76565349ba25d3ce4e Mon Sep 17 00:00:00 2001 From: kaliko Date: Sun, 25 Feb 2024 09:31:06 +0100 Subject: [PATCH] Add module mpdaio --- mpdaio-object.py | 10 +++ mpdaio-test.py | 54 +++++++++++++ mpdaio/__init__.py | 0 mpdaio/client.py | 111 +++++++++++++++++++++++++++ mpdaio/connection.py | 175 +++++++++++++++++++++++++++++++++++++++++++ mpdaio/exceptions.py | 14 ++++ 6 files changed, 364 insertions(+) create mode 100644 mpdaio-object.py create mode 100644 mpdaio-test.py create mode 100644 mpdaio/__init__.py create mode 100644 mpdaio/client.py create mode 100644 mpdaio/connection.py create mode 100644 mpdaio/exceptions.py diff --git a/mpdaio-object.py b/mpdaio-object.py new file mode 100644 index 0000000..2854f30 --- /dev/null +++ b/mpdaio-object.py @@ -0,0 +1,10 @@ +import asyncio + +from mpdaio.client import MPDClient + +async def run_cli(): + cli = MPDClient() + await cli.connect() + await cli.close() + +asyncio.run(run_cli()) diff --git a/mpdaio-test.py b/mpdaio-test.py new file mode 100644 index 0000000..255390e --- /dev/null +++ b/mpdaio-test.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 + +import logging + +from asyncio import run + +from mpdaio.connection import ConnectionPool +from mpdaio.exceptions import MPDProtocolError + + +HELLO_PREFIX = 'OK MPD ' + +async def _hello(conn): + """Consume HELLO_PREFIX""" + # await conn.drain() + # data = await conn.readline() + data = await conn.readuntil(b'\n') + rcv = data.decode('utf-8') + if not rcv.startswith(HELLO_PREFIX): + raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"') + logging.debug('consumed hello prefix') + logging.debug('"%s"', rcv) + version = rcv.split('\n')[0][len(HELLO_PREFIX):] + logging.debug('version: %s', version) + return version + + +async def lookup(pool, server, port, query): + try: + conn = await pool.connect(server, port) + logging.info(conn) + except (ValueError, OSError): + return {} + + async with conn: + await _hello(conn) + conn.write(query.encode('utf-8')) + conn.write(b'\n') + await conn.drain() + data = await conn.readuntil(b'OK\n') + rcv = data.decode('utf-8') + logging.info(rcv) + await pool.close() + + +def main(): + logging.basicConfig(level=logging.DEBUG) + pool = ConnectionPool(max_connections=10) + logging.info(pool) + run(lookup(pool, 'hispaniola.lan', 6600,'currentsong')) + + +if __name__ == '__main__': + main() diff --git a/mpdaio/__init__.py b/mpdaio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mpdaio/client.py b/mpdaio/client.py new file mode 100644 index 0000000..1d48da5 --- /dev/null +++ b/mpdaio/client.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: 2012-2024 kaliko +# SPDX-License-Identifier: LGPL-3.0-or-later + +import logging +import os + +from .connection import ConnectionPool, Connection +from .exceptions import MPDError, MPDConnectionError, MPDProtocolError + +HELLO_PREFIX = 'OK MPD ' +ERROR_PREFIX = 'ACK ' +SUCCESS = 'OK\n' +NEXT = 'list_OK' +#: Module version +VERSION = '0.9.0b2' +#: Seconds before a connection attempt times out +#: (overriden by :envvar:`MPD_TIMEOUT` env. var.) +CONNECTION_TIMEOUT = 30 +#: Socket timeout in second > 0 (Default is :py:obj:`None` for no timeout) +SOCKET_TIMEOUT = None +#: Maximum concurrent connections +CONNECTION_MAX = 10 + +logging.basicConfig(level=logging.DEBUG, + format='%(levelname)-8s %(module)-10s %(message)s') +log = logging.getLogger(__name__) + + +class MPDClient: + + def __init__(self,): + #: host used with the current connection (:py:obj:`str`) + self.host = None + #: password detected in :envvar:`MPD_HOST` environment variable (:py:obj:`str`) + self.pwd = None + #: port used with the current connection (:py:obj:`int`, :py:obj:`str`) + self.port = None + self._get_envvars() + self.pool = ConnectionPool(max_connections=CONNECTION_MAX) + log.info('logger : "%s"', __name__) + + def _get_envvars(self): + """ + Retrieve MPD env. var. to overrides default "localhost:6600" + """ + # Set some defaults + self.host = 'localhost' + self.port = os.getenv('MPD_PORT', '6600') + _host = os.getenv('MPD_HOST', '') + if _host: + # If password is set: MPD_HOST=pass@host + if '@' in _host: + mpd_host_env = _host.split('@', 1) + if mpd_host_env[0]: + # A password is actually set + log.debug( + 'password detected in MPD_HOST, set client pwd attribute') + self.pwd = mpd_host_env[0] + if mpd_host_env[1]: + self.host = mpd_host_env[1] + log.debug('host detected in MPD_HOST: %s', self.host) + elif mpd_host_env[1]: + # No password set but leading @ is an abstract socket + self.host = '@'+mpd_host_env[1] + log.debug( + 'host detected in MPD_HOST: %s (abstract socket)', self.host) + else: + # MPD_HOST is a plain host + self.host = _host + log.debug('host detected in MPD_HOST: @%s', self.host) + else: + # Is socket there + xdg_runtime_dir = os.getenv('XDG_RUNTIME_DIR', '/run') + rundir = os.path.join(xdg_runtime_dir, 'mpd/socket') + if os.path.exists(rundir): + self.host = rundir + log.debug( + 'host detected in ${XDG_RUNTIME_DIR}/run: %s (unix socket)', self.host) + _mpd_timeout = os.getenv('MPD_TIMEOUT', '') + if _mpd_timeout.isdigit(): + self.mpd_timeout = int(_mpd_timeout) + log.debug('timeout detected in MPD_TIMEOUT: %d', self.mpd_timeout) + else: # Use CONNECTION_TIMEOUT as default even if MPD_TIMEOUT carries gargage + self.mpd_timeout = CONNECTION_TIMEOUT + + async def _hello(self, conn): + """Consume HELLO_PREFIX""" + data = await conn.readuntil(b'\n') + rcv = data.decode('utf-8') + if not rcv.startswith(HELLO_PREFIX): + raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"') + log.debug('consumed hello prefix: "%s"', rcv) + self.version = rcv.split('\n')[0][len(HELLO_PREFIX):] + log.info('protocol version: %s', self.version) + + async def connect(self, server=None, port=None) -> Connection: + if not server: + server = self.host + if not port: + port = self.port + try: + conn = await self.pool.connect(server, port) + except (ValueError, OSError) as err: + raise MPDConnectionError(err) from err + async with conn: + await self._hello(conn) + + async def close(self): + await self.pool.close() + diff --git a/mpdaio/connection.py b/mpdaio/connection.py new file mode 100644 index 0000000..32ad65f --- /dev/null +++ b/mpdaio/connection.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: 2012-2024 kaliko +# SPDX-License-Identifier: LGPL-3.0-or-later +"""https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket#%E2%80%A6 +""" + +import asyncio +import contextlib +import logging + +from collections import OrderedDict +from types import TracebackType +from typing import Any, List, Optional, Tuple, Type + + +try: # Python 3.7 + base = contextlib.AbstractAsyncContextManager +except AttributeError as err: + base = object # type: ignore + +Server = str +Port = int +Host = Tuple[Server, Port] +log = logging.getLogger(__name__) + + +class ConnectionPool(base): + def __init__( + self, + max_connections: int = 1000, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + self.max_connections = max_connections + self._loop = loop or asyncio.get_event_loop() + + self._connections: OrderedDict[Host, + List["Connection"]] = OrderedDict() + self._semaphore = asyncio.Semaphore(max_connections) + + async def connect(self, server: Server, port: Port) -> "Connection": + host = (server, port) + + # enforce the connection limit, releasing connections notifies + # the semaphore to release here + await self._semaphore.acquire() + + connections = self._connections.setdefault(host, []) + log.info('got %s in pool', len(connections)) + # find an un-used connection for this host + connection = next( + (conn for conn in connections if not conn.in_use), None) + if connection is None: + # disconnect the least-recently-used un-used connection to make space + # for a new connection. There will be at least one. + for conns_per_host in reversed(self._connections.values()): + for conn in conns_per_host: + if not conn.in_use: + await conn.close() + break + + reader, writer = await asyncio.open_connection(server, port) + connection = Connection(self, host, reader, writer) + connections.append(connection) + + connection.in_use = True + # move current host to the front as most-recently used + self._connections.move_to_end(host, False) + + return connection + + async def close(self): + """Close all connections""" + connections = [c for cs in self._connections.values() for c in cs] + self._connections = OrderedDict() + for connection in connections: + await connection.close() + + def _remove(self, connection): + conns_for_host = self._connections.get(connection._host) + if not conns_for_host: + return + conns_for_host[:] = [c for c in conns_for_host if c != connection] + + def _notify_release(self): + self._semaphore.release() + + async def __aenter__(self) -> "ConnectionPool": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + await self.close() + + def __del__(self) -> None: + connections = [repr(c) + for cs in self._connections.values() for c in cs] + if not connections: + return + + context = { + "pool": self, + "connections": connections, + "message": "Unclosed connection pool", + } + self._loop.call_exception_handler(context) + + +class Connection(base): + def __init__( + self, + pool: ConnectionPool, + host: Host, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + self._host = host + self._pool = pool + self._reader = reader + self._writer = writer + self._closed = False + self.in_use = False + + def __repr__(self): + host = f"{self._host[0]}:{self._host[1]}" + return f"Connection<{host}>" + + @property + def closed(self): + return self._closed + + def release(self) -> None: + self.in_use = False + self._pool._notify_release() + + async def close(self) -> None: + if self._closed: + return + self._closed = True + self._writer.close() + self._pool._remove(self) + try: + await self._writer.wait_closed() + except AttributeError: # wait_closed is new in 3.7 + pass + + def __getattr__(self, name: str) -> Any: + """All unknown attributes are delegated to the reader and writer""" + if self._closed or not self.in_use: + raise ValueError("Can't use a closed or unacquired connection") + if hasattr(self._reader, name): + return getattr(self._reader, name) + return getattr(self._writer, name) + + async def __aenter__(self) -> "Connection": + if self._closed or not self.in_use: + raise ValueError("Can't use a closed or unacquired connection") + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + self.release() + + def __del__(self) -> None: + if self._closed: + return + context = {"connection": self, "message": "Unclosed connection"} + self._pool._loop.call_exception_handler(context) diff --git a/mpdaio/exceptions.py b/mpdaio/exceptions.py new file mode 100644 index 0000000..b86cd9a --- /dev/null +++ b/mpdaio/exceptions.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: 2012-2024 kaliko +# SPDX-License-Identifier: LGPL-3.0-or-later + +class MPDError(Exception): + """Main musicpd Exception""" + + +class MPDConnectionError(MPDError): + """Fatal Connection Error, cannot recover from it.""" + + +class MPDProtocolError(MPDError): + """Fatal Protocol Error, cannot recover from it""" -- 2.39.5