]> kaliko git repositories - python-musicpdaio.git/blobdiff - mpdaio/client.py
Fixed unused argument
[python-musicpdaio.git] / mpdaio / client.py
index 7bc483a01a0a5f6cbc418d9d131ac685d36701bc..2b971627e6d604dcd4046f524eb979c1d0cb34b2 100644 (file)
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
+# SPDX-FileCopyrightText: 2008-2010  J. Alexander Treuman <jat@spatialrift.net>
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
 import logging
@@ -9,24 +10,46 @@ from .connection import ConnectionPool, Connection
 from .exceptions import MPDConnectionError, MPDProtocolError, MPDCommandError
 from .utils import Range, escape
 
-from . import CONNECTION_MAX, CONNECTION_TIMEOUT
-from . import ERROR_PREFIX, SUCCESS, NEXT
+from .const import CONNECTION_MAX, CONNECTION_TIMEOUT
+from .const import HELLO_PREFIX, ERROR_PREFIX, SUCCESS, NEXT
 
 log = logging.getLogger(__name__)
 
 
 class MPDClient:
+    """:synopsis: Main class to instanciate building an MPD client.
 
-    def __init__(self, host: str | None = None, port: str | int | None = None, password: str | None = None):
+    :param host: MPD server IP|FQDN to connect to
+    :param port: MPD port to connect to
+    :param password: MPD password
+
+    **musicpdaio** tries to come with sane defaults, then running
+    |mpdaio.MPDClient| with no explicit argument will try default values
+    to connect to MPD. Cf. :ref:`reference` for more about
+    :ref:`defaults<default_settings>`.
+
+    The class is also exposed in mpdaio namespace.
+
+    >>> import mpdaio
+    >>> cli = mpdaio.MPDClient(host='example.org')
+    >>> print(await cli.currentsong())
+    >>> await cli.close()
+    """
+
+    def __init__(self, host: str | None = None,
+                 port: str | int | None = None,
+                 password: str | None = None):
+        #: Connection pool
         self._pool = ConnectionPool(max_connections=CONNECTION_MAX)
+        #: connection timeout
+        self.mpd_timeout = CONNECTION_TIMEOUT
         self._get_envvars()
-        #: host used with the current connection (:py:obj:`str`)
+        #: Host used to make connections (:py:obj:`str`)
         self.host = host or self.server_discovery[0]
-        #: password detected in :envvar:`MPD_HOST` environment variable (:py:obj:`str`)
-        self.password = password or self.server_discovery[2]
+        #: password used to connect (:py:obj:`str`)
+        self.pwd = password or self.server_discovery[2]
         #: port used with the current connection (:py:obj:`int`, :py:obj:`str`)
         self.port = port or self.server_discovery[1]
-        self.mpd_timeout = CONNECTION_TIMEOUT
         log.info('Using %s:%s to connect', self.host, self.port)
 
     def _get_envvars(self):
@@ -77,7 +100,7 @@ class MPDClient:
 
     def __getattr__(self, attr):
         command = attr
-        wrapper = CmdHandler(self._pool, self.host, self.port, self.password, self.mpd_timeout)
+        wrapper = CmdHandler(self._pool, self.host, self.port, self.pwd, self.mpd_timeout)
         if command not in wrapper._commands:
             command = command.replace("_", " ")
             if command not in wrapper._commands:
@@ -86,8 +109,9 @@ class MPDClient:
         return lambda *args: wrapper(command, args)
 
     @property
-    def version(self):
-        """MPD protocol version"""
+    def version(self) -> str:
+        """MPD protocol version
+        """
         host = (self.host, self.port)
         version = {_.version for _ in self.connections}
         if not version:
@@ -98,12 +122,13 @@ class MPDClient:
         return version.pop()
 
     @property
-    def connections(self):
-        """Open connections"""
+    def connections(self) -> list[Connection]:
+        """connections in the pool"""
         host = (self.host, self.port)
         return self._pool._connections.get(host, [])
 
-    async def close(self):
+    async def close(self) -> None:
+        """:synopsis: Close connections in the pool"""
         await self._pool.close()
 
 
@@ -115,7 +140,7 @@ class CmdHandler:
             "clearerror":         self._fetch_nothing,
             "currentsong":        self._fetch_object,
             "idle":               self._fetch_list,
-            # "noidle":             None,
+            "noidle":             self._fetch_nothing,
             "status":             self._fetch_object,
             "stats":              self._fetch_object,
             # Playback Option Commands
@@ -242,7 +267,7 @@ class CmdHandler:
             "sendmessage":        self._fetch_nothing,
         }
         self.command = None
-        self._command_list = None
+        self._command_list: list | None = None
         self.args = None
         self.pool = pool
         self.host = (server, port)
@@ -260,14 +285,38 @@ class CmdHandler:
         server, port = self.host
         self.command = command
         self.args = args or ''
-        self.connection = await self.pool.connect(server, port, timeout=self.timeout)
+        self.connection = await self.pool.connect(server, port, self.timeout)
         async with self.connection:
+            await self._init_connection()
             retval = self._commands[command]
             await self._write_command(command, args)
             if callable(retval):
                 return await retval()
             return retval
 
+    async def _init_connection(self):
+        """Init connection if needed
+
+        * Consumes the hello line and sets the protocol version
+        * Send password command if a password is provided
+        """
+        if not self.connection.version:
+            await self._hello()
+        if self.password and not self.connection.auth:
+            # Need to send password
+            await self._write_command('password', [self.password])
+            await self._fetch_nothing()
+            self.connection.auth = True
+
+    async def _hello(self) -> None:
+        """Consume HELLO_PREFIX"""
+        data = await self.connection.readuntil(b'\n')
+        rcv = data.decode('utf-8')
+        if not rcv.startswith(HELLO_PREFIX):
+            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
+        log.debug('consumed hello prefix: %r', rcv)
+        self.connection.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
+
     async def _write_line(self, line):
         self.connection.write(f"{line!s}\n".encode())
         await self.connection.drain()
@@ -298,7 +347,7 @@ class CmdHandler:
             amount -= len(result)
         return bytes(chunk)
 
-    async def _read_line(self, binary=False):
+    async def _read_line(self):
         line = await self.connection.readline()
         line = line.decode('utf-8')
         if not line.endswith('\n'):
@@ -317,8 +366,8 @@ class CmdHandler:
             return None
         return line
 
-    async def _read_pair(self, separator, binary=False):
-        line = await self._read_line(binary=binary)
+    async def _read_pair(self, separator):
+        line = await self._read_line()
         if line is None:
             return None
         pair = line.split(separator, 1)
@@ -326,11 +375,11 @@ class CmdHandler:
             raise MPDProtocolError(f"Could not parse pair: '{line}'")
         return pair
 
-    async def _read_pairs(self, separator=": ", binary=False):
-        pair = await self._read_pair(separator, binary=binary)
+    async def _read_pairs(self, separator=": "):
+        pair = await self._read_pair(separator)
         while pair:
             yield pair
-            pair = await self._read_pair(separator, binary=binary)
+            pair = await self._read_pair(separator)
 
     async def _read_list(self):
         seen = None
@@ -377,7 +426,7 @@ class CmdHandler:
     async def _fetch_nothing(self):
         line = await self._read_line()
         if line is not None:
-            raise ProtocolError(f"Got unexpected return value: '{line}'")
+            raise MPDProtocolError(f"Got unexpected return value: '{line}'")
 
     async def _fetch_item(self):
         pairs = [_ async for _ in self._read_pairs()]
@@ -429,7 +478,7 @@ class CmdHandler:
 
     async def _fetch_composite(self):
         obj = {}
-        async for key, value in self._read_pairs(binary=True):
+        async for key, value in self._read_pairs():
             key = key.lower()
             obj[key] = value
             if key == 'binary':
@@ -449,10 +498,10 @@ class CmdHandler:
             raise ConnectionError('Error reading binary content: '
                                   f'Expects {amount}B, got {data_bytes}')
         # Fetches trailing new line
-        await self._read_line(binary=True)
+        await self._read_line()
         #ALT: await self.connection.readuntil(b'\n')
         # Fetches SUCCESS code
-        await self._read_line(binary=True)
+        await self._read_line()
         #ALT: await self.connection.readuntil(b'OK\n')
         return obj