X-Git-Url: https://git.kaliko.me/?p=python-musicpdaio.git;a=blobdiff_plain;f=mpdaio%2Fclient.py;h=2b971627e6d604dcd4046f524eb979c1d0cb34b2;hp=a8ae0cd8d0819275e96c7614c872594d31ee96f0;hb=HEAD;hpb=d6409a4f0fc96ece23aa441516e21a85418fa3b7 diff --git a/mpdaio/client.py b/mpdaio/client.py index a8ae0cd..2b97162 100644 --- a/mpdaio/client.py +++ b/mpdaio/client.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # SPDX-FileCopyrightText: 2012-2024 kaliko +# SPDX-FileCopyrightText: 2008-2010 J. Alexander Treuman # SPDX-License-Identifier: LGPL-3.0-or-later import logging @@ -9,27 +10,47 @@ from .connection import ConnectionPool, Connection from .exceptions import MPDConnectionError, MPDProtocolError, MPDCommandError from .utils import Range, escape -from . import CONNECTION_MAX, CONNECTION_TIMEOUT -from . import ERROR_PREFIX, SUCCESS, NEXT +from .const import CONNECTION_MAX, CONNECTION_TIMEOUT +from .const import HELLO_PREFIX, ERROR_PREFIX, SUCCESS, NEXT log = logging.getLogger(__name__) class MPDClient: + """:synopsis: Main class to instanciate building an MPD client. - def __init__(self, host: str | None = None, port: str | int | None = None, password: str | None = None): + :param host: MPD server IP|FQDN to connect to + :param port: MPD port to connect to + :param password: MPD password + + **musicpdaio** tries to come with sane defaults, then running + |mpdaio.MPDClient| with no explicit argument will try default values + to connect to MPD. Cf. :ref:`reference` for more about + :ref:`defaults`. + + The class is also exposed in mpdaio namespace. + + >>> import mpdaio + >>> cli = mpdaio.MPDClient(host='example.org') + >>> print(await cli.currentsong()) + >>> await cli.close() + """ + + def __init__(self, host: str | None = None, + port: str | int | None = None, + password: str | None = None): + #: Connection pool self._pool = ConnectionPool(max_connections=CONNECTION_MAX) + #: connection timeout + self.mpd_timeout = CONNECTION_TIMEOUT self._get_envvars() - #: host used with the current connection (:py:obj:`str`) + #: Host used to make connections (:py:obj:`str`) self.host = host or self.server_discovery[0] - #: password detected in :envvar:`MPD_HOST` environment variable (:py:obj:`str`) - self.password = password or self.server_discovery[2] + #: password used to connect (:py:obj:`str`) + self.pwd = password or self.server_discovery[2] #: port used with the current connection (:py:obj:`int`, :py:obj:`str`) self.port = port or self.server_discovery[1] - log.info('logger : "%s"', __name__) - #: Protocol version - self.version: [None, str] = None - self.mpd_timeout = CONNECTION_TIMEOUT + log.info('Using %s:%s to connect', self.host, self.port) def _get_envvars(self): """ @@ -79,7 +100,7 @@ class MPDClient: def __getattr__(self, attr): command = attr - wrapper = CmdHandler(self._pool, self.host, self.port, self.password, self.mpd_timeout) + wrapper = CmdHandler(self._pool, self.host, self.port, self.pwd, self.mpd_timeout) if command not in wrapper._commands: command = command.replace("_", " ") if command not in wrapper._commands: @@ -87,13 +108,31 @@ class MPDClient: f"'CmdHandler' object has no attribute '{attr}'") return lambda *args: wrapper(command, args) - async def close(self): + @property + def version(self) -> str: + """MPD protocol version + """ + host = (self.host, self.port) + version = {_.version for _ in self.connections} + if not version: + log.warning('No connections yet in the connections pool for %s', host) + return '' + if len(version) > 1: + log.warning('More than one version in the connections pool for %s', host) + return version.pop() + + @property + def connections(self) -> list[Connection]: + """connections in the pool""" + host = (self.host, self.port) + return self._pool._connections.get(host, []) + + async def close(self) -> None: + """:synopsis: Close connections in the pool""" await self._pool.close() class CmdHandler: - #TODO: CmdHandler to intanciate in place of MPDClient._execute - # The MPDClient.__getattr__ wrapper should instanciate an CmdHandler object def __init__(self, pool, server, port, password, timeout): self._commands = { @@ -101,7 +140,7 @@ class CmdHandler: "clearerror": self._fetch_nothing, "currentsong": self._fetch_object, "idle": self._fetch_list, - # "noidle": None, + "noidle": self._fetch_nothing, "status": self._fetch_object, "stats": self._fetch_object, # Playback Option Commands @@ -228,7 +267,7 @@ class CmdHandler: "sendmessage": self._fetch_nothing, } self.command = None - self._command_list = None + self._command_list: list | None = None self.args = None self.pool = pool self.host = (server, port) @@ -246,14 +285,38 @@ class CmdHandler: server, port = self.host self.command = command self.args = args or '' - self.connection = await self.pool.connect(server, port, timeout=self.timeout) + self.connection = await self.pool.connect(server, port, self.timeout) async with self.connection: + await self._init_connection() retval = self._commands[command] await self._write_command(command, args) if callable(retval): return await retval() return retval + async def _init_connection(self): + """Init connection if needed + + * Consumes the hello line and sets the protocol version + * Send password command if a password is provided + """ + if not self.connection.version: + await self._hello() + if self.password and not self.connection.auth: + # Need to send password + await self._write_command('password', [self.password]) + await self._fetch_nothing() + self.connection.auth = True + + async def _hello(self) -> None: + """Consume HELLO_PREFIX""" + data = await self.connection.readuntil(b'\n') + rcv = data.decode('utf-8') + if not rcv.startswith(HELLO_PREFIX): + raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"') + log.debug('consumed hello prefix: %r', rcv) + self.connection.version = rcv.split('\n')[0][len(HELLO_PREFIX):] + async def _write_line(self, line): self.connection.write(f"{line!s}\n".encode()) await self.connection.drain() @@ -269,29 +332,26 @@ class CmdHandler: parts.append(f'"{escape(str(arg))}"') if '\n' in ' '.join(parts): raise MPDCommandError('new line found in the command!') - log.debug(' '.join(parts)) + #log.debug(' '.join(parts)) await self._write_line(' '.join(parts)) - def _read_binary(self, amount): + async def _read_binary(self, amount): chunk = bytearray() while amount > 0: - result = self._rbfile.read(amount) + result = await self.connection.read(amount) if len(result) == 0: - self.disconnect() + await self.connection.close() raise ConnectionError( "Connection lost while reading binary content") chunk.extend(result) amount -= len(result) return bytes(chunk) - async def _read_line(self, binary=False): - if binary: - line = self._rbfile.readline().decode('utf-8') - else: - line = await self.connection.readline() + async def _read_line(self): + line = await self.connection.readline() line = line.decode('utf-8') if not line.endswith('\n'): - await self.close() + await self.connection.close() raise MPDConnectionError("Connection lost while reading line") line = line.rstrip('\n') if line.startswith(ERROR_PREFIX): @@ -306,8 +366,8 @@ class CmdHandler: return None return line - async def _read_pair(self, separator, binary=False): - line = await self._read_line(binary=binary) + async def _read_pair(self, separator): + line = await self._read_line() if line is None: return None pair = line.split(separator, 1) @@ -315,12 +375,11 @@ class CmdHandler: raise MPDProtocolError(f"Could not parse pair: '{line}'") return pair - async def _read_pairs(self, separator=": ", binary=False): - """OK""" - pair = await self._read_pair(separator, binary=binary) + async def _read_pairs(self, separator=": "): + pair = await self._read_pair(separator) while pair: yield pair - pair = await self._read_pair(separator, binary=binary) + pair = await self._read_pair(separator) async def _read_list(self): seen = None @@ -367,7 +426,7 @@ class CmdHandler: async def _fetch_nothing(self): line = await self._read_line() if line is not None: - raise ProtocolError(f"Got unexpected return value: '{line}'") + raise MPDProtocolError(f"Got unexpected return value: '{line}'") async def _fetch_item(self): pairs = [_ async for _ in self._read_pairs()] @@ -419,7 +478,7 @@ class CmdHandler: async def _fetch_composite(self): obj = {} - for key, value in self._read_pairs(binary=True): + async for key, value in self._read_pairs(): key = key.lower() obj[key] = value if key == 'binary': @@ -430,7 +489,7 @@ class CmdHandler: return obj amount = int(obj['binary']) try: - obj['data'] = self._read_binary(amount) + obj['data'] = await self._read_binary(amount) except IOError as err: raise ConnectionError( f'Error reading binary content: {err}') from err @@ -439,9 +498,11 @@ class CmdHandler: raise ConnectionError('Error reading binary content: ' f'Expects {amount}B, got {data_bytes}') # Fetches trailing new line - await self._read_line(binary=True) + await self._read_line() + #ALT: await self.connection.readuntil(b'\n') # Fetches SUCCESS code - await self._read_line(binary=True) + await self._read_line() + #ALT: await self.connection.readuntil(b'OK\n') return obj async def _fetch_command_list(self):