X-Git-Url: https://git.kaliko.me/?p=python-musicpdaio.git;a=blobdiff_plain;f=mpdaio%2Fclient.py;h=2b971627e6d604dcd4046f524eb979c1d0cb34b2;hp=190f809f19ba1bb89675c92153ecca1167d8303f;hb=HEAD;hpb=96255f59a1d8f380b0e3c028fd2ac5fa4d76f778 diff --git a/mpdaio/client.py b/mpdaio/client.py index 190f809..2b97162 100644 --- a/mpdaio/client.py +++ b/mpdaio/client.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # SPDX-FileCopyrightText: 2012-2024 kaliko +# SPDX-FileCopyrightText: 2008-2010 J. Alexander Treuman # SPDX-License-Identifier: LGPL-3.0-or-later import logging @@ -9,34 +10,137 @@ from .connection import ConnectionPool, Connection from .exceptions import MPDConnectionError, MPDProtocolError, MPDCommandError from .utils import Range, escape -HELLO_PREFIX = 'OK MPD ' -ERROR_PREFIX = 'ACK ' -SUCCESS = 'OK' -NEXT = 'list_OK' -#: Module version -VERSION = '0.9.0b2' -#: Seconds before a connection attempt times out -#: (overriden by :envvar:`MPD_TIMEOUT` env. var.) -CONNECTION_TIMEOUT = 30 -#: Socket timeout in second > 0 (Default is :py:obj:`None` for no timeout) -SOCKET_TIMEOUT = None -#: Maximum concurrent connections -CONNECTION_MAX = 10 - -logging.basicConfig(level=logging.DEBUG, - format='%(levelname)-8s %(module)-10s %(message)s') +from .const import CONNECTION_MAX, CONNECTION_TIMEOUT +from .const import HELLO_PREFIX, ERROR_PREFIX, SUCCESS, NEXT + log = logging.getLogger(__name__) class MPDClient: + """:synopsis: Main class to instanciate building an MPD client. + + :param host: MPD server IP|FQDN to connect to + :param port: MPD port to connect to + :param password: MPD password + + **musicpdaio** tries to come with sane defaults, then running + |mpdaio.MPDClient| with no explicit argument will try default values + to connect to MPD. Cf. :ref:`reference` for more about + :ref:`defaults`. + + The class is also exposed in mpdaio namespace. + + >>> import mpdaio + >>> cli = mpdaio.MPDClient(host='example.org') + >>> print(await cli.currentsong()) + >>> await cli.close() + """ + + def __init__(self, host: str | None = None, + port: str | int | None = None, + password: str | None = None): + #: Connection pool + self._pool = ConnectionPool(max_connections=CONNECTION_MAX) + #: connection timeout + self.mpd_timeout = CONNECTION_TIMEOUT + self._get_envvars() + #: Host used to make connections (:py:obj:`str`) + self.host = host or self.server_discovery[0] + #: password used to connect (:py:obj:`str`) + self.pwd = password or self.server_discovery[2] + #: port used with the current connection (:py:obj:`int`, :py:obj:`str`) + self.port = port or self.server_discovery[1] + log.info('Using %s:%s to connect', self.host, self.port) - def __init__(self,): + def _get_envvars(self): + """ + Retrieve MPD env. var. to overrides default "localhost:6600" + """ + # Set some defaults + disco_host = 'localhost' + disco_port = os.getenv('MPD_PORT', '6600') + pwd = None + _host = os.getenv('MPD_HOST', '') + if _host: + # If password is set: MPD_HOST=pass@host + if '@' in _host: + mpd_host_env = _host.split('@', 1) + if mpd_host_env[0]: + # A password is actually set + log.debug( + 'password detected in MPD_HOST, set client pwd attribute') + pwd = mpd_host_env[0] + if mpd_host_env[1]: + disco_host = mpd_host_env[1] + log.debug('host detected in MPD_HOST: %s', disco_host) + elif mpd_host_env[1]: + # No password set but leading @ is an abstract socket + disco_host = '@'+mpd_host_env[1] + log.debug( + 'host detected in MPD_HOST: %s (abstract socket)', disco_host) + else: + # MPD_HOST is a plain host + disco_host = _host + log.debug('host detected in MPD_HOST: %s', disco_host) + else: + # Is socket there + xdg_runtime_dir = os.getenv('XDG_RUNTIME_DIR', '/run') + rundir = os.path.join(xdg_runtime_dir, 'mpd/socket') + if os.path.exists(rundir): + disco_host = rundir + log.debug( + 'host detected in ${XDG_RUNTIME_DIR}/run: %s (unix socket)', disco_host) + _mpd_timeout = os.getenv('MPD_TIMEOUT', '') + if _mpd_timeout.isdigit(): + self.mpd_timeout = int(_mpd_timeout) + log.debug('timeout detected in MPD_TIMEOUT: %d', self.mpd_timeout) + else: # Use CONNECTION_TIMEOUT as default even if MPD_TIMEOUT carries gargage + self.mpd_timeout = CONNECTION_TIMEOUT + self.server_discovery = (disco_host, disco_port, pwd) + + def __getattr__(self, attr): + command = attr + wrapper = CmdHandler(self._pool, self.host, self.port, self.pwd, self.mpd_timeout) + if command not in wrapper._commands: + command = command.replace("_", " ") + if command not in wrapper._commands: + raise AttributeError( + f"'CmdHandler' object has no attribute '{attr}'") + return lambda *args: wrapper(command, args) + + @property + def version(self) -> str: + """MPD protocol version + """ + host = (self.host, self.port) + version = {_.version for _ in self.connections} + if not version: + log.warning('No connections yet in the connections pool for %s', host) + return '' + if len(version) > 1: + log.warning('More than one version in the connections pool for %s', host) + return version.pop() + + @property + def connections(self) -> list[Connection]: + """connections in the pool""" + host = (self.host, self.port) + return self._pool._connections.get(host, []) + + async def close(self) -> None: + """:synopsis: Close connections in the pool""" + await self._pool.close() + + +class CmdHandler: + + def __init__(self, pool, server, port, password, timeout): self._commands = { # Status Commands "clearerror": self._fetch_nothing, "currentsong": self._fetch_object, "idle": self._fetch_list, - # "noidle": None, + "noidle": self._fetch_nothing, "status": self._fetch_object, "stats": self._fetch_object, # Playback Option Commands @@ -162,105 +266,56 @@ class MPDClient: "readmessages": self._fetch_messages, "sendmessage": self._fetch_nothing, } - #: host used with the current connection (:py:obj:`str`) - self.host = None - #: password detected in :envvar:`MPD_HOST` environment variable (:py:obj:`str`) - self.pwd = None - #: port used with the current connection (:py:obj:`int`, :py:obj:`str`) - self.port = None - self._get_envvars() - self._pool = ConnectionPool(max_connections=CONNECTION_MAX) - log.info('logger : "%s"', __name__) + self.command = None + self._command_list: list | None = None + self.args = None + self.pool = pool + self.host = (server, port) + self.password = password + self.timeout = timeout #: current connection - self.connection: [None,Connection] = None - #: Protocol version - self.version: [None,str] = None - self._command_list = None + self.connection: [None, Connection] = None + + def __repr__(self): + args = [str(_) for _ in self.args] + args = ','.join(args or []) + return f'{self.command}({args})' + + async def __call__(self, command: str, args: list | None): + server, port = self.host + self.command = command + self.args = args or '' + self.connection = await self.pool.connect(server, port, self.timeout) + async with self.connection: + await self._init_connection() + retval = self._commands[command] + await self._write_command(command, args) + if callable(retval): + return await retval() + return retval - def _get_envvars(self): - """ - Retrieve MPD env. var. to overrides default "localhost:6600" - """ - # Set some defaults - self.host = 'localhost' - self.port = os.getenv('MPD_PORT', '6600') - _host = os.getenv('MPD_HOST', '') - if _host: - # If password is set: MPD_HOST=pass@host - if '@' in _host: - mpd_host_env = _host.split('@', 1) - if mpd_host_env[0]: - # A password is actually set - log.debug( - 'password detected in MPD_HOST, set client pwd attribute') - self.pwd = mpd_host_env[0] - if mpd_host_env[1]: - self.host = mpd_host_env[1] - log.debug('host detected in MPD_HOST: %s', self.host) - elif mpd_host_env[1]: - # No password set but leading @ is an abstract socket - self.host = '@'+mpd_host_env[1] - log.debug( - 'host detected in MPD_HOST: %s (abstract socket)', self.host) - else: - # MPD_HOST is a plain host - self.host = _host - log.debug('host detected in MPD_HOST: %s', self.host) - else: - # Is socket there - xdg_runtime_dir = os.getenv('XDG_RUNTIME_DIR', '/run') - rundir = os.path.join(xdg_runtime_dir, 'mpd/socket') - if os.path.exists(rundir): - self.host = rundir - log.debug( - 'host detected in ${XDG_RUNTIME_DIR}/run: %s (unix socket)', self.host) - _mpd_timeout = os.getenv('MPD_TIMEOUT', '') - if _mpd_timeout.isdigit(): - self.mpd_timeout = int(_mpd_timeout) - log.debug('timeout detected in MPD_TIMEOUT: %d', self.mpd_timeout) - else: # Use CONNECTION_TIMEOUT as default even if MPD_TIMEOUT carries gargage - self.mpd_timeout = CONNECTION_TIMEOUT + async def _init_connection(self): + """Init connection if needed - def __getattr__(self, attr): - # if attr == 'send_noidle': # have send_noidle to cancel idle as well as noidle - # return self.noidle - if attr.startswith("send_"): - command = attr.replace("send_", "", 1) - wrapper = self._send - elif attr.startswith("fetch_"): - command = attr.replace("fetch_", "", 1) - wrapper = self._fetch - else: - command = attr - wrapper = self._execute - if command not in self._commands: - command = command.replace("_", " ") - if command not in self._commands: - cls = self.__class__.__name__ - raise AttributeError( - f"'{cls}' object has no attribute '{attr}'") - return lambda *args: wrapper(command, args) + * Consumes the hello line and sets the protocol version + * Send password command if a password is provided + """ + if not self.connection.version: + await self._hello() + if self.password and not self.connection.auth: + # Need to send password + await self._write_command('password', [self.password]) + await self._fetch_nothing() + self.connection.auth = True - async def _execute(self, command, args): # pylint: disable=unused-argument - self.connection = await self._pool.connect(self.host, self.port) - async with self.connection: - # if self._pending: - # raise MPDCommandError( - # f"Cannot execute '{command}' with pending commands") - retval = self._commands[command] - if self._command_list is not None: - if not callable(retval): - raise MPDCommandError( - f"'{command}' not allowed in command list") - self._write_command(command, args) - self._command_list.append(retval) - else: - await self._write_command(command, args) - if callable(retval): - log.debug('retvat: %s', retval) - return await retval() - return retval - return None + async def _hello(self) -> None: + """Consume HELLO_PREFIX""" + data = await self.connection.readuntil(b'\n') + rcv = data.decode('utf-8') + if not rcv.startswith(HELLO_PREFIX): + raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"') + log.debug('consumed hello prefix: %r', rcv) + self.connection.version = rcv.split('\n')[0][len(HELLO_PREFIX):] async def _write_line(self, line): self.connection.write(f"{line!s}\n".encode()) @@ -277,28 +332,26 @@ class MPDClient: parts.append(f'"{escape(str(arg))}"') if '\n' in ' '.join(parts): raise MPDCommandError('new line found in the command!') + #log.debug(' '.join(parts)) await self._write_line(' '.join(parts)) - def _read_binary(self, amount): + async def _read_binary(self, amount): chunk = bytearray() while amount > 0: - result = self._rbfile.read(amount) + result = await self.connection.read(amount) if len(result) == 0: - self.disconnect() + await self.connection.close() raise ConnectionError( "Connection lost while reading binary content") chunk.extend(result) amount -= len(result) return bytes(chunk) - async def _read_line(self, binary=False): - if binary: - line = self._rbfile.readline().decode('utf-8') - else: - line = await self.connection.readline() + async def _read_line(self): + line = await self.connection.readline() line = line.decode('utf-8') if not line.endswith('\n'): - await self.close() + await self.connection.close() raise MPDConnectionError("Connection lost while reading line") line = line.rstrip('\n') if line.startswith(ERROR_PREFIX): @@ -313,8 +366,8 @@ class MPDClient: return None return line - async def _read_pair(self, separator, binary=False): - line = await self._read_line(binary=binary) + async def _read_pair(self, separator): + line = await self._read_line() if line is None: return None pair = line.split(separator, 1) @@ -322,24 +375,24 @@ class MPDClient: raise MPDProtocolError(f"Could not parse pair: '{line}'") return pair - async def _read_pairs(self, separator=": ", binary=False): - """OK""" - pair = await self._read_pair(separator, binary=binary) + async def _read_pairs(self, separator=": "): + pair = await self._read_pair(separator) while pair: yield pair - pair = await self._read_pair(separator, binary=binary) + pair = await self._read_pair(separator) async def _read_list(self): seen = None - for key, value in await self._read_pairs(): + async for key, value in self._read_pairs(): if key != seen: if seen is not None: - raise MPDProtocolError(f"Expected key '{seen}', got '{key}'") + raise MPDProtocolError( + f"Expected key '{seen}', got '{key}'") seen = key yield value async def _read_playlist(self): - for _, value in await self._read_pairs(":"): + async for _, value in self._read_pairs(":"): yield value async def _read_objects(self, delimiters=None): @@ -362,30 +415,30 @@ class MPDClient: if obj: yield obj - def _read_command_list(self): + async def _read_command_list(self): try: for retval in self._command_list: yield retval() finally: self._command_list = None - self._fetch_nothing() + await self._fetch_nothing() async def _fetch_nothing(self): line = await self._read_line() if line is not None: - raise ProtocolError(f"Got unexpected return value: '{line}'") + raise MPDProtocolError(f"Got unexpected return value: '{line}'") async def _fetch_item(self): - pairs = list(await self._read_pairs()) + pairs = [_ async for _ in self._read_pairs()] if len(pairs) != 1: return None return pairs[0][1] - def _fetch_list(self): - return self._read_list() + async def _fetch_list(self): + return [_ async for _ in self._read_list()] - def _fetch_playlist(self): - return self._read_playlist() + async def _fetch_playlist(self): + return [_ async for _ in self._read_pairs(':')] async def _fetch_object(self): objs = [obj async for obj in self._read_objects()] @@ -396,36 +449,36 @@ class MPDClient: async def _fetch_objects(self, delimiters): return [_ async for _ in self._read_objects(delimiters)] - def _fetch_changes(self): - return self._fetch_objects(["cpos"]) + async def _fetch_changes(self): + return await self._fetch_objects(["cpos"]) async def _fetch_songs(self): return await self._fetch_objects(["file"]) - def _fetch_playlists(self): - return self._fetch_objects(["playlist"]) + async def _fetch_playlists(self): + return await self._fetch_objects(["playlist"]) - def _fetch_database(self): - return self._fetch_objects(["file", "directory", "playlist"]) + async def _fetch_database(self): + return await self._fetch_objects(["file", "directory", "playlist"]) - def _fetch_outputs(self): - return self._fetch_objects(["outputid"]) + async def _fetch_outputs(self): + return await self._fetch_objects(["outputid"]) - def _fetch_plugins(self): - return self._fetch_objects(["plugin"]) + async def _fetch_plugins(self): + return await self._fetch_objects(["plugin"]) - def _fetch_messages(self): - return self._fetch_objects(["channel"]) + async def _fetch_messages(self): + return await self._fetch_objects(["channel"]) - def _fetch_mounts(self): - return self._fetch_objects(["mount"]) + async def _fetch_mounts(self): + return await self._fetch_objects(["mount"]) - def _fetch_neighbors(self): - return self._fetch_objects(["neighbor"]) + async def _fetch_neighbors(self): + return await self._fetch_objects(["neighbor"]) async def _fetch_composite(self): obj = {} - for key, value in self._read_pairs(binary=True): + async for key, value in self._read_pairs(): key = key.lower() obj[key] = value if key == 'binary': @@ -436,7 +489,7 @@ class MPDClient: return obj amount = int(obj['binary']) try: - obj['data'] = self._read_binary(amount) + obj['data'] = await self._read_binary(amount) except IOError as err: raise ConnectionError( f'Error reading binary content: {err}') from err @@ -445,36 +498,12 @@ class MPDClient: raise ConnectionError('Error reading binary content: ' f'Expects {amount}B, got {data_bytes}') # Fetches trailing new line - await self._read_line(binary=True) + await self._read_line() + #ALT: await self.connection.readuntil(b'\n') # Fetches SUCCESS code - await self._read_line(binary=True) + await self._read_line() + #ALT: await self.connection.readuntil(b'OK\n') return obj - def _fetch_command_list(self): - return self._read_command_list() - - async def _hello(self): - """Consume HELLO_PREFIX""" - data = await self.connection.readuntil(b'\n') - rcv = data.decode('utf-8') - if not rcv.startswith(HELLO_PREFIX): - raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"') - log.debug('consumed hello prefix: %r', rcv) - self.version = rcv.split('\n')[0][len(HELLO_PREFIX):] - log.info('protocol version: %s', self.version) - - async def connect(self, server=None, port=None) -> Connection: - if not server: - server = self.host - if not port: - port = self.port - try: - self.connection = await self._pool.connect(server, port) - except (ValueError, OSError) as err: - #TODO: Is this the right way to raise Excep - raise MPDConnectionError(err) from err - async with self.connection: - await self._hello() - - async def close(self): - await self.connection.close() + async def _fetch_command_list(self): + return [_ async for _ in self._read_command_list()]