From 96255f59a1d8f380b0e3c028fd2ac5fa4d76f778 Mon Sep 17 00:00:00 2001 From: kaliko Date: Sun, 25 Feb 2024 17:28:28 +0100 Subject: [PATCH] Partially implement protocol (AsyncIO POC) --- mpdaio-object.py | 3 + mpdaio/client.py | 393 +++++++++++++++++++++++++++++++++++++++++-- mpdaio/connection.py | 4 +- mpdaio/exceptions.py | 3 + mpdaio/utils.py | 47 ++++++ 5 files changed, 437 insertions(+), 13 deletions(-) create mode 100644 mpdaio/utils.py diff --git a/mpdaio-object.py b/mpdaio-object.py index 2854f30..0691bb2 100644 --- a/mpdaio-object.py +++ b/mpdaio-object.py @@ -5,6 +5,9 @@ from mpdaio.client import MPDClient async def run_cli(): cli = MPDClient() await cli.connect() + current = await cli.currentsong() + print(current) + print(await cli.playlistinfo()) await cli.close() asyncio.run(run_cli()) diff --git a/mpdaio/client.py b/mpdaio/client.py index 1d48da5..190f809 100644 --- a/mpdaio/client.py +++ b/mpdaio/client.py @@ -6,11 +6,12 @@ import logging import os from .connection import ConnectionPool, Connection -from .exceptions import MPDError, MPDConnectionError, MPDProtocolError +from .exceptions import MPDConnectionError, MPDProtocolError, MPDCommandError +from .utils import Range, escape HELLO_PREFIX = 'OK MPD ' ERROR_PREFIX = 'ACK ' -SUCCESS = 'OK\n' +SUCCESS = 'OK' NEXT = 'list_OK' #: Module version VERSION = '0.9.0b2' @@ -30,6 +31,137 @@ log = logging.getLogger(__name__) class MPDClient: def __init__(self,): + self._commands = { + # Status Commands + "clearerror": self._fetch_nothing, + "currentsong": self._fetch_object, + "idle": self._fetch_list, + # "noidle": None, + "status": self._fetch_object, + "stats": self._fetch_object, + # Playback Option Commands + "consume": self._fetch_nothing, + "crossfade": self._fetch_nothing, + "mixrampdb": self._fetch_nothing, + "mixrampdelay": self._fetch_nothing, + "random": self._fetch_nothing, + "repeat": self._fetch_nothing, + "setvol": self._fetch_nothing, + "getvol": self._fetch_object, + "single": self._fetch_nothing, + "replay_gain_mode": self._fetch_nothing, + "replay_gain_status": self._fetch_item, + "volume": self._fetch_nothing, + # Playback Control Commands + "next": self._fetch_nothing, + "pause": self._fetch_nothing, + "play": self._fetch_nothing, + "playid": self._fetch_nothing, + "previous": self._fetch_nothing, + "seek": self._fetch_nothing, + "seekid": self._fetch_nothing, + "seekcur": self._fetch_nothing, + "stop": self._fetch_nothing, + # Queue Commands + "add": self._fetch_nothing, + "addid": self._fetch_item, + "clear": self._fetch_nothing, + "delete": self._fetch_nothing, + "deleteid": self._fetch_nothing, + "move": self._fetch_nothing, + "moveid": self._fetch_nothing, + "playlist": self._fetch_playlist, + "playlistfind": self._fetch_songs, + "playlistid": self._fetch_songs, + "playlistinfo": self._fetch_songs, + "playlistsearch": self._fetch_songs, + "plchanges": self._fetch_songs, + "plchangesposid": self._fetch_changes, + "prio": self._fetch_nothing, + "prioid": self._fetch_nothing, + "rangeid": self._fetch_nothing, + "shuffle": self._fetch_nothing, + "swap": self._fetch_nothing, + "swapid": self._fetch_nothing, + "addtagid": self._fetch_nothing, + "cleartagid": self._fetch_nothing, + # Stored Playlist Commands + "listplaylist": self._fetch_list, + "listplaylistinfo": self._fetch_songs, + "listplaylists": self._fetch_playlists, + "load": self._fetch_nothing, + "playlistadd": self._fetch_nothing, + "playlistclear": self._fetch_nothing, + "playlistdelete": self._fetch_nothing, + "playlistmove": self._fetch_nothing, + "rename": self._fetch_nothing, + "rm": self._fetch_nothing, + "save": self._fetch_nothing, + # Database Commands + "albumart": self._fetch_composite, + "count": self._fetch_object, + "getfingerprint": self._fetch_object, + "find": self._fetch_songs, + "findadd": self._fetch_nothing, + "list": self._fetch_list, + "listall": self._fetch_database, + "listallinfo": self._fetch_database, + "listfiles": self._fetch_database, + "lsinfo": self._fetch_database, + "readcomments": self._fetch_object, + "readpicture": self._fetch_composite, + "search": self._fetch_songs, + "searchadd": self._fetch_nothing, + "searchaddpl": self._fetch_nothing, + "update": self._fetch_item, + "rescan": self._fetch_item, + # Mounts and neighbors + "mount": self._fetch_nothing, + "unmount": self._fetch_nothing, + "listmounts": self._fetch_mounts, + "listneighbors": self._fetch_neighbors, + # Sticker Commands + "sticker get": self._fetch_item, + "sticker set": self._fetch_nothing, + "sticker delete": self._fetch_nothing, + "sticker list": self._fetch_list, + "sticker find": self._fetch_songs, + # Connection Commands + "close": None, + "kill": None, + "password": self._fetch_nothing, + "ping": self._fetch_nothing, + "binarylimit": self._fetch_nothing, + "tagtypes": self._fetch_list, + "tagtypes disable": self._fetch_nothing, + "tagtypes enable": self._fetch_nothing, + "tagtypes clear": self._fetch_nothing, + "tagtypes all": self._fetch_nothing, + # Partition Commands + "partition": self._fetch_nothing, + "listpartitions": self._fetch_list, + "newpartition": self._fetch_nothing, + "delpartition": self._fetch_nothing, + "moveoutput": self._fetch_nothing, + # Audio Output Commands + "disableoutput": self._fetch_nothing, + "enableoutput": self._fetch_nothing, + "toggleoutput": self._fetch_nothing, + "outputs": self._fetch_outputs, + "outputset": self._fetch_nothing, + # Reflection Commands + "config": self._fetch_object, + "commands": self._fetch_list, + "notcommands": self._fetch_list, + "urlhandlers": self._fetch_list, + "decoders": self._fetch_plugins, + # Client to Client + "subscribe": self._fetch_nothing, + "unsubscribe": self._fetch_nothing, + "channels": self._fetch_list, + "readmessages": self._fetch_messages, + "sendmessage": self._fetch_nothing, + } #: host used with the current connection (:py:obj:`str`) self.host = None #: password detected in :envvar:`MPD_HOST` environment variable (:py:obj:`str`) @@ -37,8 +169,13 @@ class MPDClient: #: port used with the current connection (:py:obj:`int`, :py:obj:`str`) self.port = None self._get_envvars() - self.pool = ConnectionPool(max_connections=CONNECTION_MAX) + self._pool = ConnectionPool(max_connections=CONNECTION_MAX) log.info('logger : "%s"', __name__) + #: current connection + self.connection: [None,Connection] = None + #: Protocol version + self.version: [None,str] = None + self._command_list = None def _get_envvars(self): """ @@ -68,7 +205,7 @@ class MPDClient: else: # MPD_HOST is a plain host self.host = _host - log.debug('host detected in MPD_HOST: @%s', self.host) + log.debug('host detected in MPD_HOST: %s', self.host) else: # Is socket there xdg_runtime_dir = os.getenv('XDG_RUNTIME_DIR', '/run') @@ -84,13 +221,245 @@ class MPDClient: else: # Use CONNECTION_TIMEOUT as default even if MPD_TIMEOUT carries gargage self.mpd_timeout = CONNECTION_TIMEOUT - async def _hello(self, conn): + def __getattr__(self, attr): + # if attr == 'send_noidle': # have send_noidle to cancel idle as well as noidle + # return self.noidle + if attr.startswith("send_"): + command = attr.replace("send_", "", 1) + wrapper = self._send + elif attr.startswith("fetch_"): + command = attr.replace("fetch_", "", 1) + wrapper = self._fetch + else: + command = attr + wrapper = self._execute + if command not in self._commands: + command = command.replace("_", " ") + if command not in self._commands: + cls = self.__class__.__name__ + raise AttributeError( + f"'{cls}' object has no attribute '{attr}'") + return lambda *args: wrapper(command, args) + + async def _execute(self, command, args): # pylint: disable=unused-argument + self.connection = await self._pool.connect(self.host, self.port) + async with self.connection: + # if self._pending: + # raise MPDCommandError( + # f"Cannot execute '{command}' with pending commands") + retval = self._commands[command] + if self._command_list is not None: + if not callable(retval): + raise MPDCommandError( + f"'{command}' not allowed in command list") + self._write_command(command, args) + self._command_list.append(retval) + else: + await self._write_command(command, args) + if callable(retval): + log.debug('retvat: %s', retval) + return await retval() + return retval + return None + + async def _write_line(self, line): + self.connection.write(f"{line!s}\n".encode()) + await self.connection.drain() + + async def _write_command(self, command, args=None): + if args is None: + args = [] + parts = [command] + for arg in args: + if isinstance(arg, tuple): + parts.append(f'{Range(arg)!s}') + else: + parts.append(f'"{escape(str(arg))}"') + if '\n' in ' '.join(parts): + raise MPDCommandError('new line found in the command!') + await self._write_line(' '.join(parts)) + + def _read_binary(self, amount): + chunk = bytearray() + while amount > 0: + result = self._rbfile.read(amount) + if len(result) == 0: + self.disconnect() + raise ConnectionError( + "Connection lost while reading binary content") + chunk.extend(result) + amount -= len(result) + return bytes(chunk) + + async def _read_line(self, binary=False): + if binary: + line = self._rbfile.readline().decode('utf-8') + else: + line = await self.connection.readline() + line = line.decode('utf-8') + if not line.endswith('\n'): + await self.close() + raise MPDConnectionError("Connection lost while reading line") + line = line.rstrip('\n') + if line.startswith(ERROR_PREFIX): + error = line[len(ERROR_PREFIX):].strip() + raise MPDCommandError(error) + if self._command_list is not None: + if line == NEXT: + return None + if line == SUCCESS: + raise MPDProtocolError(f"Got unexpected '{SUCCESS}'") + elif line == SUCCESS: + return None + return line + + async def _read_pair(self, separator, binary=False): + line = await self._read_line(binary=binary) + if line is None: + return None + pair = line.split(separator, 1) + if len(pair) < 2: + raise MPDProtocolError(f"Could not parse pair: '{line}'") + return pair + + async def _read_pairs(self, separator=": ", binary=False): + """OK""" + pair = await self._read_pair(separator, binary=binary) + while pair: + yield pair + pair = await self._read_pair(separator, binary=binary) + + async def _read_list(self): + seen = None + for key, value in await self._read_pairs(): + if key != seen: + if seen is not None: + raise MPDProtocolError(f"Expected key '{seen}', got '{key}'") + seen = key + yield value + + async def _read_playlist(self): + for _, value in await self._read_pairs(":"): + yield value + + async def _read_objects(self, delimiters=None): + obj = {} + if delimiters is None: + delimiters = [] + async for key, value in self._read_pairs(): + key = key.lower() + if obj: + if key in delimiters: + yield obj + obj = {} + elif key in obj: + if not isinstance(obj[key], list): + obj[key] = [obj[key], value] + else: + obj[key].append(value) + continue + obj[key] = value + if obj: + yield obj + + def _read_command_list(self): + try: + for retval in self._command_list: + yield retval() + finally: + self._command_list = None + self._fetch_nothing() + + async def _fetch_nothing(self): + line = await self._read_line() + if line is not None: + raise ProtocolError(f"Got unexpected return value: '{line}'") + + async def _fetch_item(self): + pairs = list(await self._read_pairs()) + if len(pairs) != 1: + return None + return pairs[0][1] + + def _fetch_list(self): + return self._read_list() + + def _fetch_playlist(self): + return self._read_playlist() + + async def _fetch_object(self): + objs = [obj async for obj in self._read_objects()] + if not objs: + return {} + return objs[0] + + async def _fetch_objects(self, delimiters): + return [_ async for _ in self._read_objects(delimiters)] + + def _fetch_changes(self): + return self._fetch_objects(["cpos"]) + + async def _fetch_songs(self): + return await self._fetch_objects(["file"]) + + def _fetch_playlists(self): + return self._fetch_objects(["playlist"]) + + def _fetch_database(self): + return self._fetch_objects(["file", "directory", "playlist"]) + + def _fetch_outputs(self): + return self._fetch_objects(["outputid"]) + + def _fetch_plugins(self): + return self._fetch_objects(["plugin"]) + + def _fetch_messages(self): + return self._fetch_objects(["channel"]) + + def _fetch_mounts(self): + return self._fetch_objects(["mount"]) + + def _fetch_neighbors(self): + return self._fetch_objects(["neighbor"]) + + async def _fetch_composite(self): + obj = {} + for key, value in self._read_pairs(binary=True): + key = key.lower() + obj[key] = value + if key == 'binary': + break + if not obj: + # If the song file was recognized, but there is no picture, the + # response is successful, but is otherwise empty. + return obj + amount = int(obj['binary']) + try: + obj['data'] = self._read_binary(amount) + except IOError as err: + raise ConnectionError( + f'Error reading binary content: {err}') from err + data_bytes = len(obj['data']) + if data_bytes != amount: # can we ever get there? + raise ConnectionError('Error reading binary content: ' + f'Expects {amount}B, got {data_bytes}') + # Fetches trailing new line + await self._read_line(binary=True) + # Fetches SUCCESS code + await self._read_line(binary=True) + return obj + + def _fetch_command_list(self): + return self._read_command_list() + + async def _hello(self): """Consume HELLO_PREFIX""" - data = await conn.readuntil(b'\n') + data = await self.connection.readuntil(b'\n') rcv = data.decode('utf-8') if not rcv.startswith(HELLO_PREFIX): raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"') - log.debug('consumed hello prefix: "%s"', rcv) + log.debug('consumed hello prefix: %r', rcv) self.version = rcv.split('\n')[0][len(HELLO_PREFIX):] log.info('protocol version: %s', self.version) @@ -100,12 +469,12 @@ class MPDClient: if not port: port = self.port try: - conn = await self.pool.connect(server, port) + self.connection = await self._pool.connect(server, port) except (ValueError, OSError) as err: + #TODO: Is this the right way to raise Excep raise MPDConnectionError(err) from err - async with conn: - await self._hello(conn) + async with self.connection: + await self._hello() async def close(self): - await self.pool.close() - + await self.connection.close() diff --git a/mpdaio/connection.py b/mpdaio/connection.py index 32ad65f..d6f025b 100644 --- a/mpdaio/connection.py +++ b/mpdaio/connection.py @@ -45,7 +45,7 @@ class ConnectionPool(base): await self._semaphore.acquire() connections = self._connections.setdefault(host, []) - log.info('got %s in pool', len(connections)) + log.debug('got %s in pool', len(connections)) # find an un-used connection for this host connection = next( (conn for conn in connections if not conn.in_use), None) @@ -133,12 +133,14 @@ class Connection(base): return self._closed def release(self) -> None: + logging.debug('releasing %r', self) self.in_use = False self._pool._notify_release() async def close(self) -> None: if self._closed: return + logging.debug('closing %r', self) self._closed = True self._writer.close() self._pool._remove(self) diff --git a/mpdaio/exceptions.py b/mpdaio/exceptions.py index b86cd9a..1ddf340 100644 --- a/mpdaio/exceptions.py +++ b/mpdaio/exceptions.py @@ -12,3 +12,6 @@ class MPDConnectionError(MPDError): class MPDProtocolError(MPDError): """Fatal Protocol Error, cannot recover from it""" + +class MPDCommandError(MPDError): + """Malformed command, socket should be fine, can reuse it""" diff --git a/mpdaio/utils.py b/mpdaio/utils.py new file mode 100644 index 0000000..1d01db5 --- /dev/null +++ b/mpdaio/utils.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: 2012-2024 kaliko +# SPDX-License-Identifier: LGPL-3.0-or-later + +class Range: + + def __init__(self, tpl): + self.tpl = tpl + self.lower = '' + self.upper = '' + self._check() + + def __str__(self): + return f'{self.lower}:{self.upper}' + + def __repr__(self): + return f'Range({self.tpl})' + + def _check_element(self, item): + if item is None or item == '': + return '' + try: + return str(int(item)) + except (TypeError, ValueError) as err: + raise CommandError(f'Not an integer: "{item}"') from err + return item + + def _check(self): + if not isinstance(self.tpl, tuple): + raise CommandError('Wrong type, provide a tuple') + if len(self.tpl) == 0: + return + if len(self.tpl) == 1: + self.lower = self._check_element(self.tpl[0]) + return + if len(self.tpl) != 2: + raise CommandError('Range wrong size (0, 1 or 2 allowed)') + self.lower = self._check_element(self.tpl[0]) + self.upper = self._check_element(self.tpl[1]) + if self.lower == '' and self.upper != '': + raise CommandError(f'Integer expected to start the range: {self.tpl}') + if self.upper.isdigit() and self.lower.isdigit(): + if int(self.lower) > int(self.upper): + raise CommandError(f'Wrong range: {self.lower} > {self.upper}') + +def escape(text): + return text.replace("\\", "\\\\").replace('"', '\\"') -- 2.39.2