]> kaliko git repositories - python-musicpdaio.git/blobdiff - mpdaio/client.py
Fixed unused argument
[python-musicpdaio.git] / mpdaio / client.py
index 190f809f19ba1bb89675c92153ecca1167d8303f..2b971627e6d604dcd4046f524eb979c1d0cb34b2 100644 (file)
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
+# SPDX-FileCopyrightText: 2008-2010  J. Alexander Treuman <jat@spatialrift.net>
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
 import logging
@@ -9,34 +10,137 @@ from .connection import ConnectionPool, Connection
 from .exceptions import MPDConnectionError, MPDProtocolError, MPDCommandError
 from .utils import Range, escape
 
-HELLO_PREFIX = 'OK MPD '
-ERROR_PREFIX = 'ACK '
-SUCCESS = 'OK'
-NEXT = 'list_OK'
-#: Module version
-VERSION = '0.9.0b2'
-#: Seconds before a connection attempt times out
-#: (overriden by :envvar:`MPD_TIMEOUT` env. var.)
-CONNECTION_TIMEOUT = 30
-#: Socket timeout in second > 0 (Default is :py:obj:`None` for no timeout)
-SOCKET_TIMEOUT = None
-#: Maximum concurrent connections
-CONNECTION_MAX = 10
-
-logging.basicConfig(level=logging.DEBUG,
-                    format='%(levelname)-8s %(module)-10s %(message)s')
+from .const import CONNECTION_MAX, CONNECTION_TIMEOUT
+from .const import HELLO_PREFIX, ERROR_PREFIX, SUCCESS, NEXT
+
 log = logging.getLogger(__name__)
 
 
 class MPDClient:
+    """:synopsis: Main class to instanciate building an MPD client.
+
+    :param host: MPD server IP|FQDN to connect to
+    :param port: MPD port to connect to
+    :param password: MPD password
+
+    **musicpdaio** tries to come with sane defaults, then running
+    |mpdaio.MPDClient| with no explicit argument will try default values
+    to connect to MPD. Cf. :ref:`reference` for more about
+    :ref:`defaults<default_settings>`.
+
+    The class is also exposed in mpdaio namespace.
+
+    >>> import mpdaio
+    >>> cli = mpdaio.MPDClient(host='example.org')
+    >>> print(await cli.currentsong())
+    >>> await cli.close()
+    """
+
+    def __init__(self, host: str | None = None,
+                 port: str | int | None = None,
+                 password: str | None = None):
+        #: Connection pool
+        self._pool = ConnectionPool(max_connections=CONNECTION_MAX)
+        #: connection timeout
+        self.mpd_timeout = CONNECTION_TIMEOUT
+        self._get_envvars()
+        #: Host used to make connections (:py:obj:`str`)
+        self.host = host or self.server_discovery[0]
+        #: password used to connect (:py:obj:`str`)
+        self.pwd = password or self.server_discovery[2]
+        #: port used with the current connection (:py:obj:`int`, :py:obj:`str`)
+        self.port = port or self.server_discovery[1]
+        log.info('Using %s:%s to connect', self.host, self.port)
 
-    def __init__(self,):
+    def _get_envvars(self):
+        """
+        Retrieve MPD env. var. to overrides default "localhost:6600"
+        """
+        # Set some defaults
+        disco_host = 'localhost'
+        disco_port = os.getenv('MPD_PORT', '6600')
+        pwd = None
+        _host = os.getenv('MPD_HOST', '')
+        if _host:
+            # If password is set: MPD_HOST=pass@host
+            if '@' in _host:
+                mpd_host_env = _host.split('@', 1)
+                if mpd_host_env[0]:
+                    # A password is actually set
+                    log.debug(
+                        'password detected in MPD_HOST, set client pwd attribute')
+                    pwd = mpd_host_env[0]
+                    if mpd_host_env[1]:
+                        disco_host = mpd_host_env[1]
+                        log.debug('host detected in MPD_HOST: %s', disco_host)
+                elif mpd_host_env[1]:
+                    # No password set but leading @ is an abstract socket
+                    disco_host = '@'+mpd_host_env[1]
+                    log.debug(
+                        'host detected in MPD_HOST: %s (abstract socket)', disco_host)
+            else:
+                # MPD_HOST is a plain host
+                disco_host = _host
+                log.debug('host detected in MPD_HOST: %s', disco_host)
+        else:
+            # Is socket there
+            xdg_runtime_dir = os.getenv('XDG_RUNTIME_DIR', '/run')
+            rundir = os.path.join(xdg_runtime_dir, 'mpd/socket')
+            if os.path.exists(rundir):
+                disco_host = rundir
+                log.debug(
+                    'host detected in ${XDG_RUNTIME_DIR}/run: %s (unix socket)', disco_host)
+        _mpd_timeout = os.getenv('MPD_TIMEOUT', '')
+        if _mpd_timeout.isdigit():
+            self.mpd_timeout = int(_mpd_timeout)
+            log.debug('timeout detected in MPD_TIMEOUT: %d', self.mpd_timeout)
+        else:  # Use CONNECTION_TIMEOUT as default even if MPD_TIMEOUT carries gargage
+            self.mpd_timeout = CONNECTION_TIMEOUT
+        self.server_discovery = (disco_host, disco_port, pwd)
+
+    def __getattr__(self, attr):
+        command = attr
+        wrapper = CmdHandler(self._pool, self.host, self.port, self.pwd, self.mpd_timeout)
+        if command not in wrapper._commands:
+            command = command.replace("_", " ")
+            if command not in wrapper._commands:
+                raise AttributeError(
+                    f"'CmdHandler' object has no attribute '{attr}'")
+        return lambda *args: wrapper(command, args)
+
+    @property
+    def version(self) -> str:
+        """MPD protocol version
+        """
+        host = (self.host, self.port)
+        version = {_.version for _ in self.connections}
+        if not version:
+            log.warning('No connections yet in the connections pool for %s', host)
+            return ''
+        if len(version) > 1:
+            log.warning('More than one version in the connections pool for %s', host)
+        return version.pop()
+
+    @property
+    def connections(self) -> list[Connection]:
+        """connections in the pool"""
+        host = (self.host, self.port)
+        return self._pool._connections.get(host, [])
+
+    async def close(self) -> None:
+        """:synopsis: Close connections in the pool"""
+        await self._pool.close()
+
+
+class CmdHandler:
+
+    def __init__(self, pool, server, port, password, timeout):
         self._commands = {
             # Status Commands
             "clearerror":         self._fetch_nothing,
             "currentsong":        self._fetch_object,
             "idle":               self._fetch_list,
-            # "noidle":             None,
+            "noidle":             self._fetch_nothing,
             "status":             self._fetch_object,
             "stats":              self._fetch_object,
             # Playback Option Commands
@@ -162,105 +266,56 @@ class MPDClient:
             "readmessages":       self._fetch_messages,
             "sendmessage":        self._fetch_nothing,
         }
-        #: host used with the current connection (:py:obj:`str`)
-        self.host = None
-        #: password detected in :envvar:`MPD_HOST` environment variable (:py:obj:`str`)
-        self.pwd = None
-        #: port used with the current connection (:py:obj:`int`, :py:obj:`str`)
-        self.port = None
-        self._get_envvars()
-        self._pool = ConnectionPool(max_connections=CONNECTION_MAX)
-        log.info('logger : "%s"', __name__)
+        self.command = None
+        self._command_list: list | None = None
+        self.args = None
+        self.pool = pool
+        self.host = (server, port)
+        self.password = password
+        self.timeout = timeout
         #: current connection
-        self.connection: [None,Connection] = None
-        #: Protocol version
-        self.version: [None,str] = None
-        self._command_list = None
+        self.connection: [None, Connection] = None
+
+    def __repr__(self):
+        args = [str(_) for _ in self.args]
+        args = ','.join(args or [])
+        return f'{self.command}({args})'
+
+    async def __call__(self, command: str, args: list | None):
+        server, port = self.host
+        self.command = command
+        self.args = args or ''
+        self.connection = await self.pool.connect(server, port, self.timeout)
+        async with self.connection:
+            await self._init_connection()
+            retval = self._commands[command]
+            await self._write_command(command, args)
+            if callable(retval):
+                return await retval()
+            return retval
 
-    def _get_envvars(self):
-        """
-        Retrieve MPD env. var. to overrides default "localhost:6600"
-        """
-        # Set some defaults
-        self.host = 'localhost'
-        self.port = os.getenv('MPD_PORT', '6600')
-        _host = os.getenv('MPD_HOST', '')
-        if _host:
-            # If password is set: MPD_HOST=pass@host
-            if '@' in _host:
-                mpd_host_env = _host.split('@', 1)
-                if mpd_host_env[0]:
-                    # A password is actually set
-                    log.debug(
-                        'password detected in MPD_HOST, set client pwd attribute')
-                    self.pwd = mpd_host_env[0]
-                    if mpd_host_env[1]:
-                        self.host = mpd_host_env[1]
-                        log.debug('host detected in MPD_HOST: %s', self.host)
-                elif mpd_host_env[1]:
-                    # No password set but leading @ is an abstract socket
-                    self.host = '@'+mpd_host_env[1]
-                    log.debug(
-                        'host detected in MPD_HOST: %s (abstract socket)', self.host)
-            else:
-                # MPD_HOST is a plain host
-                self.host = _host
-                log.debug('host detected in MPD_HOST: %s', self.host)
-        else:
-            # Is socket there
-            xdg_runtime_dir = os.getenv('XDG_RUNTIME_DIR', '/run')
-            rundir = os.path.join(xdg_runtime_dir, 'mpd/socket')
-            if os.path.exists(rundir):
-                self.host = rundir
-                log.debug(
-                    'host detected in ${XDG_RUNTIME_DIR}/run: %s (unix socket)', self.host)
-        _mpd_timeout = os.getenv('MPD_TIMEOUT', '')
-        if _mpd_timeout.isdigit():
-            self.mpd_timeout = int(_mpd_timeout)
-            log.debug('timeout detected in MPD_TIMEOUT: %d', self.mpd_timeout)
-        else:  # Use CONNECTION_TIMEOUT as default even if MPD_TIMEOUT carries gargage
-            self.mpd_timeout = CONNECTION_TIMEOUT
+    async def _init_connection(self):
+        """Init connection if needed
 
-    def __getattr__(self, attr):
-        # if attr == 'send_noidle':  # have send_noidle to cancel idle as well as noidle
-        #     return self.noidle
-        if attr.startswith("send_"):
-            command = attr.replace("send_", "", 1)
-            wrapper = self._send
-        elif attr.startswith("fetch_"):
-            command = attr.replace("fetch_", "", 1)
-            wrapper = self._fetch
-        else:
-            command = attr
-            wrapper = self._execute
-        if command not in self._commands:
-            command = command.replace("_", " ")
-            if command not in self._commands:
-                cls = self.__class__.__name__
-                raise AttributeError(
-                    f"'{cls}' object has no attribute '{attr}'")
-        return lambda *args: wrapper(command, args)
+        * Consumes the hello line and sets the protocol version
+        * Send password command if a password is provided
+        """
+        if not self.connection.version:
+            await self._hello()
+        if self.password and not self.connection.auth:
+            # Need to send password
+            await self._write_command('password', [self.password])
+            await self._fetch_nothing()
+            self.connection.auth = True
 
-    async def _execute(self, command, args):  # pylint: disable=unused-argument
-        self.connection = await self._pool.connect(self.host, self.port)
-        async with self.connection:
-            # if self._pending:
-            #     raise MPDCommandError(
-            #         f"Cannot execute '{command}' with pending commands")
-            retval = self._commands[command]
-            if self._command_list is not None:
-                if not callable(retval):
-                    raise MPDCommandError(
-                        f"'{command}' not allowed in command list")
-                self._write_command(command, args)
-                self._command_list.append(retval)
-            else:
-                await self._write_command(command, args)
-                if callable(retval):
-                    log.debug('retvat: %s', retval)
-                    return await retval()
-                return retval
-            return None
+    async def _hello(self) -> None:
+        """Consume HELLO_PREFIX"""
+        data = await self.connection.readuntil(b'\n')
+        rcv = data.decode('utf-8')
+        if not rcv.startswith(HELLO_PREFIX):
+            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
+        log.debug('consumed hello prefix: %r', rcv)
+        self.connection.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
 
     async def _write_line(self, line):
         self.connection.write(f"{line!s}\n".encode())
@@ -277,28 +332,26 @@ class MPDClient:
                 parts.append(f'"{escape(str(arg))}"')
         if '\n' in ' '.join(parts):
             raise MPDCommandError('new line found in the command!')
+        #log.debug(' '.join(parts))
         await self._write_line(' '.join(parts))
 
-    def _read_binary(self, amount):
+    async def _read_binary(self, amount):
         chunk = bytearray()
         while amount > 0:
-            result = self._rbfile.read(amount)
+            result = await self.connection.read(amount)
             if len(result) == 0:
-                self.disconnect()
+                await self.connection.close()
                 raise ConnectionError(
                     "Connection lost while reading binary content")
             chunk.extend(result)
             amount -= len(result)
         return bytes(chunk)
 
-    async def _read_line(self, binary=False):
-        if binary:
-            line = self._rbfile.readline().decode('utf-8')
-        else:
-            line = await self.connection.readline()
+    async def _read_line(self):
+        line = await self.connection.readline()
         line = line.decode('utf-8')
         if not line.endswith('\n'):
-            await self.close()
+            await self.connection.close()
             raise MPDConnectionError("Connection lost while reading line")
         line = line.rstrip('\n')
         if line.startswith(ERROR_PREFIX):
@@ -313,8 +366,8 @@ class MPDClient:
             return None
         return line
 
-    async def _read_pair(self, separator, binary=False):
-        line = await self._read_line(binary=binary)
+    async def _read_pair(self, separator):
+        line = await self._read_line()
         if line is None:
             return None
         pair = line.split(separator, 1)
@@ -322,24 +375,24 @@ class MPDClient:
             raise MPDProtocolError(f"Could not parse pair: '{line}'")
         return pair
 
-    async def _read_pairs(self, separator=": ", binary=False):
-        """OK"""
-        pair = await self._read_pair(separator, binary=binary)
+    async def _read_pairs(self, separator=": "):
+        pair = await self._read_pair(separator)
         while pair:
             yield pair
-            pair = await self._read_pair(separator, binary=binary)
+            pair = await self._read_pair(separator)
 
     async def _read_list(self):
         seen = None
-        for key, value in await self._read_pairs():
+        async for key, value in self._read_pairs():
             if key != seen:
                 if seen is not None:
-                    raise MPDProtocolError(f"Expected key '{seen}', got '{key}'")
+                    raise MPDProtocolError(
+                        f"Expected key '{seen}', got '{key}'")
                 seen = key
             yield value
 
     async def _read_playlist(self):
-        for _, value in await self._read_pairs(":"):
+        async for _, value in self._read_pairs(":"):
             yield value
 
     async def _read_objects(self, delimiters=None):
@@ -362,30 +415,30 @@ class MPDClient:
         if obj:
             yield obj
 
-    def _read_command_list(self):
+    async def _read_command_list(self):
         try:
             for retval in self._command_list:
                 yield retval()
         finally:
             self._command_list = None
-        self._fetch_nothing()
+        await self._fetch_nothing()
 
     async def _fetch_nothing(self):
         line = await self._read_line()
         if line is not None:
-            raise ProtocolError(f"Got unexpected return value: '{line}'")
+            raise MPDProtocolError(f"Got unexpected return value: '{line}'")
 
     async def _fetch_item(self):
-        pairs = list(await self._read_pairs())
+        pairs = [_ async for _ in self._read_pairs()]
         if len(pairs) != 1:
             return None
         return pairs[0][1]
 
-    def _fetch_list(self):
-        return self._read_list()
+    async def _fetch_list(self):
+        return [_ async for _ in self._read_list()]
 
-    def _fetch_playlist(self):
-        return self._read_playlist()
+    async def _fetch_playlist(self):
+        return [_ async for _ in self._read_pairs(':')]
 
     async def _fetch_object(self):
         objs = [obj async for obj in self._read_objects()]
@@ -396,36 +449,36 @@ class MPDClient:
     async def _fetch_objects(self, delimiters):
         return [_ async for _ in self._read_objects(delimiters)]
 
-    def _fetch_changes(self):
-        return self._fetch_objects(["cpos"])
+    async def _fetch_changes(self):
+        return await self._fetch_objects(["cpos"])
 
     async def _fetch_songs(self):
         return await self._fetch_objects(["file"])
 
-    def _fetch_playlists(self):
-        return self._fetch_objects(["playlist"])
+    async def _fetch_playlists(self):
+        return await self._fetch_objects(["playlist"])
 
-    def _fetch_database(self):
-        return self._fetch_objects(["file", "directory", "playlist"])
+    async def _fetch_database(self):
+        return await self._fetch_objects(["file", "directory", "playlist"])
 
-    def _fetch_outputs(self):
-        return self._fetch_objects(["outputid"])
+    async def _fetch_outputs(self):
+        return await self._fetch_objects(["outputid"])
 
-    def _fetch_plugins(self):
-        return self._fetch_objects(["plugin"])
+    async def _fetch_plugins(self):
+        return await self._fetch_objects(["plugin"])
 
-    def _fetch_messages(self):
-        return self._fetch_objects(["channel"])
+    async def _fetch_messages(self):
+        return await self._fetch_objects(["channel"])
 
-    def _fetch_mounts(self):
-        return self._fetch_objects(["mount"])
+    async def _fetch_mounts(self):
+        return await self._fetch_objects(["mount"])
 
-    def _fetch_neighbors(self):
-        return self._fetch_objects(["neighbor"])
+    async def _fetch_neighbors(self):
+        return await self._fetch_objects(["neighbor"])
 
     async def _fetch_composite(self):
         obj = {}
-        for key, value in self._read_pairs(binary=True):
+        async for key, value in self._read_pairs():
             key = key.lower()
             obj[key] = value
             if key == 'binary':
@@ -436,7 +489,7 @@ class MPDClient:
             return obj
         amount = int(obj['binary'])
         try:
-            obj['data'] = self._read_binary(amount)
+            obj['data'] = await self._read_binary(amount)
         except IOError as err:
             raise ConnectionError(
                 f'Error reading binary content: {err}') from err
@@ -445,36 +498,12 @@ class MPDClient:
             raise ConnectionError('Error reading binary content: '
                                   f'Expects {amount}B, got {data_bytes}')
         # Fetches trailing new line
-        await self._read_line(binary=True)
+        await self._read_line()
+        #ALT: await self.connection.readuntil(b'\n')
         # Fetches SUCCESS code
-        await self._read_line(binary=True)
+        await self._read_line()
+        #ALT: await self.connection.readuntil(b'OK\n')
         return obj
 
-    def _fetch_command_list(self):
-        return self._read_command_list()
-
-    async def _hello(self):
-        """Consume HELLO_PREFIX"""
-        data = await self.connection.readuntil(b'\n')
-        rcv = data.decode('utf-8')
-        if not rcv.startswith(HELLO_PREFIX):
-            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
-        log.debug('consumed hello prefix: %r', rcv)
-        self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
-        log.info('protocol version: %s', self.version)
-
-    async def connect(self, server=None, port=None) -> Connection:
-        if not server:
-            server = self.host
-        if not port:
-            port = self.port
-        try:
-            self.connection = await self._pool.connect(server, port)
-        except (ValueError, OSError) as err:
-            #TODO: Is this the right way to raise Excep
-            raise MPDConnectionError(err) from err
-        async with self.connection:
-            await self._hello()
-
-    async def close(self):
-        await self.connection.close()
+    async def _fetch_command_list(self):
+        return [_ async for _ in self._read_command_list()]