]> kaliko git repositories - python-musicpdaio.git/commitdiff
Add module mpdaio
authorkaliko <kaliko@azylum.org>
Sun, 25 Feb 2024 08:31:06 +0000 (09:31 +0100)
committerkaliko <kaliko@azylum.org>
Sun, 25 Feb 2024 08:31:06 +0000 (09:31 +0100)
mpdaio-object.py [new file with mode: 0644]
mpdaio-test.py [new file with mode: 0644]
mpdaio/__init__.py [new file with mode: 0644]
mpdaio/client.py [new file with mode: 0644]
mpdaio/connection.py [new file with mode: 0644]
mpdaio/exceptions.py [new file with mode: 0644]

diff --git a/mpdaio-object.py b/mpdaio-object.py
new file mode 100644 (file)
index 0000000..2854f30
--- /dev/null
@@ -0,0 +1,10 @@
+import asyncio
+
+from mpdaio.client import MPDClient
+
+async def run_cli():
+    cli = MPDClient()
+    await cli.connect()
+    await cli.close()
+
+asyncio.run(run_cli())
diff --git a/mpdaio-test.py b/mpdaio-test.py
new file mode 100644 (file)
index 0000000..255390e
--- /dev/null
@@ -0,0 +1,54 @@
+#!/usr/bin/python3
+
+import logging
+
+from asyncio import run
+
+from mpdaio.connection import ConnectionPool
+from mpdaio.exceptions import MPDProtocolError
+
+
+HELLO_PREFIX = 'OK MPD '
+
+async def _hello(conn):
+    """Consume HELLO_PREFIX"""
+    # await conn.drain()
+    # data = await conn.readline()
+    data = await conn.readuntil(b'\n')
+    rcv = data.decode('utf-8')
+    if not rcv.startswith(HELLO_PREFIX):
+        raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
+    logging.debug('consumed hello prefix')
+    logging.debug('"%s"', rcv)
+    version = rcv.split('\n')[0][len(HELLO_PREFIX):]
+    logging.debug('version: %s', version)
+    return version
+
+
+async def lookup(pool, server, port, query):
+    try:
+        conn = await pool.connect(server, port)
+        logging.info(conn)
+    except (ValueError, OSError):
+        return {}
+
+    async with conn:
+        await _hello(conn)
+        conn.write(query.encode('utf-8'))
+        conn.write(b'\n')
+        await conn.drain()
+        data = await conn.readuntil(b'OK\n')
+        rcv = data.decode('utf-8')
+        logging.info(rcv)
+    await pool.close()
+
+
+def main():
+    logging.basicConfig(level=logging.DEBUG)
+    pool = ConnectionPool(max_connections=10)
+    logging.info(pool)
+    run(lookup(pool, 'hispaniola.lan', 6600,'currentsong'))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/mpdaio/__init__.py b/mpdaio/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/mpdaio/client.py b/mpdaio/client.py
new file mode 100644 (file)
index 0000000..1d48da5
--- /dev/null
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+# SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
+# SPDX-License-Identifier: LGPL-3.0-or-later
+
+import logging
+import os
+
+from .connection import ConnectionPool, Connection
+from .exceptions import MPDError, MPDConnectionError, MPDProtocolError
+
+HELLO_PREFIX = 'OK MPD '
+ERROR_PREFIX = 'ACK '
+SUCCESS = 'OK\n'
+NEXT = 'list_OK'
+#: Module version
+VERSION = '0.9.0b2'
+#: Seconds before a connection attempt times out
+#: (overriden by :envvar:`MPD_TIMEOUT` env. var.)
+CONNECTION_TIMEOUT = 30
+#: Socket timeout in second > 0 (Default is :py:obj:`None` for no timeout)
+SOCKET_TIMEOUT = None
+#: Maximum concurrent connections
+CONNECTION_MAX = 10
+
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(levelname)-8s %(module)-10s %(message)s')
+log = logging.getLogger(__name__)
+
+
+class MPDClient:
+
+    def __init__(self,):
+        #: host used with the current connection (:py:obj:`str`)
+        self.host = None
+        #: password detected in :envvar:`MPD_HOST` environment variable (:py:obj:`str`)
+        self.pwd = None
+        #: port used with the current connection (:py:obj:`int`, :py:obj:`str`)
+        self.port = None
+        self._get_envvars()
+        self.pool = ConnectionPool(max_connections=CONNECTION_MAX)
+        log.info('logger : "%s"', __name__)
+
+    def _get_envvars(self):
+        """
+        Retrieve MPD env. var. to overrides default "localhost:6600"
+        """
+        # Set some defaults
+        self.host = 'localhost'
+        self.port = os.getenv('MPD_PORT', '6600')
+        _host = os.getenv('MPD_HOST', '')
+        if _host:
+            # If password is set: MPD_HOST=pass@host
+            if '@' in _host:
+                mpd_host_env = _host.split('@', 1)
+                if mpd_host_env[0]:
+                    # A password is actually set
+                    log.debug(
+                        'password detected in MPD_HOST, set client pwd attribute')
+                    self.pwd = mpd_host_env[0]
+                    if mpd_host_env[1]:
+                        self.host = mpd_host_env[1]
+                        log.debug('host detected in MPD_HOST: %s', self.host)
+                elif mpd_host_env[1]:
+                    # No password set but leading @ is an abstract socket
+                    self.host = '@'+mpd_host_env[1]
+                    log.debug(
+                        'host detected in MPD_HOST: %s (abstract socket)', self.host)
+            else:
+                # MPD_HOST is a plain host
+                self.host = _host
+                log.debug('host detected in MPD_HOST: @%s', self.host)
+        else:
+            # Is socket there
+            xdg_runtime_dir = os.getenv('XDG_RUNTIME_DIR', '/run')
+            rundir = os.path.join(xdg_runtime_dir, 'mpd/socket')
+            if os.path.exists(rundir):
+                self.host = rundir
+                log.debug(
+                    'host detected in ${XDG_RUNTIME_DIR}/run: %s (unix socket)', self.host)
+        _mpd_timeout = os.getenv('MPD_TIMEOUT', '')
+        if _mpd_timeout.isdigit():
+            self.mpd_timeout = int(_mpd_timeout)
+            log.debug('timeout detected in MPD_TIMEOUT: %d', self.mpd_timeout)
+        else:  # Use CONNECTION_TIMEOUT as default even if MPD_TIMEOUT carries gargage
+            self.mpd_timeout = CONNECTION_TIMEOUT
+
+    async def _hello(self, conn):
+        """Consume HELLO_PREFIX"""
+        data = await conn.readuntil(b'\n')
+        rcv = data.decode('utf-8')
+        if not rcv.startswith(HELLO_PREFIX):
+            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
+        log.debug('consumed hello prefix: "%s"', rcv)
+        self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
+        log.info('protocol version: %s', self.version)
+
+    async def connect(self, server=None, port=None) -> Connection:
+        if not server:
+            server = self.host
+        if not port:
+            port = self.port
+        try:
+            conn = await self.pool.connect(server, port)
+        except (ValueError, OSError) as err:
+            raise MPDConnectionError(err) from err
+        async with conn:
+            await self._hello(conn)
+
+    async def close(self):
+        await self.pool.close()
+
diff --git a/mpdaio/connection.py b/mpdaio/connection.py
new file mode 100644 (file)
index 0000000..32ad65f
--- /dev/null
@@ -0,0 +1,175 @@
+# -*- coding: utf-8 -*-
+# SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket#%E2%80%A6
+"""
+
+import asyncio
+import contextlib
+import logging
+
+from collections import OrderedDict
+from types import TracebackType
+from typing import Any, List, Optional, Tuple, Type
+
+
+try:  # Python 3.7
+    base = contextlib.AbstractAsyncContextManager
+except AttributeError as err:
+    base = object  # type: ignore
+
+Server = str
+Port = int
+Host = Tuple[Server, Port]
+log = logging.getLogger(__name__)
+
+
+class ConnectionPool(base):
+    def __init__(
+        self,
+        max_connections: int = 1000,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ):
+        self.max_connections = max_connections
+        self._loop = loop or asyncio.get_event_loop()
+
+        self._connections: OrderedDict[Host,
+                                       List["Connection"]] = OrderedDict()
+        self._semaphore = asyncio.Semaphore(max_connections)
+
+    async def connect(self, server: Server, port: Port) -> "Connection":
+        host = (server, port)
+
+        # enforce the connection limit, releasing connections notifies
+        # the semaphore to release here
+        await self._semaphore.acquire()
+
+        connections = self._connections.setdefault(host, [])
+        log.info('got %s in pool', len(connections))
+        # find an un-used connection for this host
+        connection = next(
+            (conn for conn in connections if not conn.in_use), None)
+        if connection is None:
+            # disconnect the least-recently-used un-used connection to make space
+            # for a new connection. There will be at least one.
+            for conns_per_host in reversed(self._connections.values()):
+                for conn in conns_per_host:
+                    if not conn.in_use:
+                        await conn.close()
+                        break
+
+            reader, writer = await asyncio.open_connection(server, port)
+            connection = Connection(self, host, reader, writer)
+            connections.append(connection)
+
+        connection.in_use = True
+        # move current host to the front as most-recently used
+        self._connections.move_to_end(host, False)
+
+        return connection
+
+    async def close(self):
+        """Close all connections"""
+        connections = [c for cs in self._connections.values() for c in cs]
+        self._connections = OrderedDict()
+        for connection in connections:
+            await connection.close()
+
+    def _remove(self, connection):
+        conns_for_host = self._connections.get(connection._host)
+        if not conns_for_host:
+            return
+        conns_for_host[:] = [c for c in conns_for_host if c != connection]
+
+    def _notify_release(self):
+        self._semaphore.release()
+
+    async def __aenter__(self) -> "ConnectionPool":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
+        await self.close()
+
+    def __del__(self) -> None:
+        connections = [repr(c)
+                       for cs in self._connections.values() for c in cs]
+        if not connections:
+            return
+
+        context = {
+            "pool": self,
+            "connections": connections,
+            "message": "Unclosed connection pool",
+        }
+        self._loop.call_exception_handler(context)
+
+
+class Connection(base):
+    def __init__(
+        self,
+        pool: ConnectionPool,
+        host: Host,
+        reader: asyncio.StreamReader,
+        writer: asyncio.StreamWriter,
+    ):
+        self._host = host
+        self._pool = pool
+        self._reader = reader
+        self._writer = writer
+        self._closed = False
+        self.in_use = False
+
+    def __repr__(self):
+        host = f"{self._host[0]}:{self._host[1]}"
+        return f"Connection<{host}>"
+
+    @property
+    def closed(self):
+        return self._closed
+
+    def release(self) -> None:
+        self.in_use = False
+        self._pool._notify_release()
+
+    async def close(self) -> None:
+        if self._closed:
+            return
+        self._closed = True
+        self._writer.close()
+        self._pool._remove(self)
+        try:
+            await self._writer.wait_closed()
+        except AttributeError:  # wait_closed is new in 3.7
+            pass
+
+    def __getattr__(self, name: str) -> Any:
+        """All unknown attributes are delegated to the reader and writer"""
+        if self._closed or not self.in_use:
+            raise ValueError("Can't use a closed or unacquired connection")
+        if hasattr(self._reader, name):
+            return getattr(self._reader, name)
+        return getattr(self._writer, name)
+
+    async def __aenter__(self) -> "Connection":
+        if self._closed or not self.in_use:
+            raise ValueError("Can't use a closed or unacquired connection")
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
+        self.release()
+
+    def __del__(self) -> None:
+        if self._closed:
+            return
+        context = {"connection": self, "message": "Unclosed connection"}
+        self._pool._loop.call_exception_handler(context)
diff --git a/mpdaio/exceptions.py b/mpdaio/exceptions.py
new file mode 100644 (file)
index 0000000..b86cd9a
--- /dev/null
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
+# SPDX-License-Identifier: LGPL-3.0-or-later
+
+class MPDError(Exception):
+    """Main musicpd Exception"""
+
+
+class MPDConnectionError(MPDError):
+    """Fatal Connection Error, cannot recover from it."""
+
+
+class MPDProtocolError(MPDError):
+    """Fatal Protocol Error, cannot recover from it"""