]> kaliko git repositories - python-musicpdaio.git/commitdiff
Some improvement on ConnectionPool
authorkaliko <kaliko@azylum.org>
Sat, 2 Mar 2024 17:10:52 +0000 (18:10 +0100)
committerkaliko <kaliko@azylum.org>
Sat, 2 Mar 2024 17:10:52 +0000 (18:10 +0100)
Still have to deal with self.connection shared between different MPD
client command call to be truly concurrent.

mpdaio-object.py
mpdaio-test.py [deleted file]
mpdaio-time.py [new file with mode: 0644]
mpdaio/__init__.py
mpdaio/client.py
mpdaio/connection.py

index 0691bb210cd7788c9f40a2580c52735f9aa77240..56533fa062d3ed2226c8ad539a1a1a83031182a0 100644 (file)
@@ -1,13 +1,26 @@
 import asyncio
+import timeit
 
 from mpdaio.client import MPDClient
 
 async def run_cli():
     cli = MPDClient()
-    await cli.connect()
-    current = await cli.currentsong()
-    print(current)
-    print(await cli.playlistinfo())
+    cli.mpd_timeout = 0.1
+    #await cli.connect()
+    #current = await cli.currentsong()
+    #print(current)
+
+    await cli.currentsong()
+    await cli.playlistinfo()
+    await cli.list('artist')
+    #print(await cli.playlistinfo())
     await cli.close()
 
-asyncio.run(run_cli())
+if __name__ == '__main__':
+    asyncio.run(run_cli())
+    asyncio.run(run_cli())
+    #t = timeit.Timer('asyncio.run(run_cli())', globals=globals())
+    #print(t.autorange())
+    with asyncio.Runner() as runner:
+        runner.run(run_cli())
+        runner.run(run_cli())
diff --git a/mpdaio-test.py b/mpdaio-test.py
deleted file mode 100644 (file)
index 255390e..0000000
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/python3
-
-import logging
-
-from asyncio import run
-
-from mpdaio.connection import ConnectionPool
-from mpdaio.exceptions import MPDProtocolError
-
-
-HELLO_PREFIX = 'OK MPD '
-
-async def _hello(conn):
-    """Consume HELLO_PREFIX"""
-    # await conn.drain()
-    # data = await conn.readline()
-    data = await conn.readuntil(b'\n')
-    rcv = data.decode('utf-8')
-    if not rcv.startswith(HELLO_PREFIX):
-        raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
-    logging.debug('consumed hello prefix')
-    logging.debug('"%s"', rcv)
-    version = rcv.split('\n')[0][len(HELLO_PREFIX):]
-    logging.debug('version: %s', version)
-    return version
-
-
-async def lookup(pool, server, port, query):
-    try:
-        conn = await pool.connect(server, port)
-        logging.info(conn)
-    except (ValueError, OSError):
-        return {}
-
-    async with conn:
-        await _hello(conn)
-        conn.write(query.encode('utf-8'))
-        conn.write(b'\n')
-        await conn.drain()
-        data = await conn.readuntil(b'OK\n')
-        rcv = data.decode('utf-8')
-        logging.info(rcv)
-    await pool.close()
-
-
-def main():
-    logging.basicConfig(level=logging.DEBUG)
-    pool = ConnectionPool(max_connections=10)
-    logging.info(pool)
-    run(lookup(pool, 'hispaniola.lan', 6600,'currentsong'))
-
-
-if __name__ == '__main__':
-    main()
diff --git a/mpdaio-time.py b/mpdaio-time.py
new file mode 100644 (file)
index 0000000..1e21080
--- /dev/null
@@ -0,0 +1,71 @@
+import asyncio
+import logging
+import timeit
+
+from mpdaio.client import MPDClient
+from musicpd import MPDClient as MPDClientNAIO
+
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(levelname)-8s %(module)-10s %(message)s')
+
+async def run_cli():
+    cli = MPDClient()
+    cli.mpd_timeout = 0.1
+    #current = await cli.currentsong()
+    #print(current)
+
+    #await cli.connect(host='kaliko.me', port='6601')
+    cli = MPDClient(host='kaliko.me', port=6601)
+    cli.mpd_timeout = 0.1
+    print(await cli.currentsong())
+    print(await cli.playlistinfo())
+    await cli.list('artist')
+    #print(await cli.playlistinfo())
+    await cli.close()
+
+
+async def aio():
+    cli = MPDClient(host='kaliko.me', port=6601)
+    # Group tasks together
+    try:
+        await asyncio.gather(
+            cli.currentsong(),
+            # cli.playlistinfo(),
+            # cli.list('artist'),
+            # cli.listallinfo('The Doors'),
+            # cli.listallinfo('AFX')
+            )
+        # await asyncio.gather(
+        #     cli.currentsong()
+        #     )
+    finally:
+        # finally close
+        await cli.close()
+
+
+def noaio():
+    cli = MPDClientNAIO()
+    cli.mpd_timeout = 0.1
+    cli.connect(host='kaliko.me', port='6601')
+    cli.currentsong()
+    cli.playlistinfo()
+    cli.list('artist')
+    cli.listallinfo('The Doors')
+    cli.listallinfo('AFX')
+    # finally close
+    cli.disconnect()
+
+if __name__ == '__main__':
+    asyncio.run(aio())
+    asyncio.run(run_cli())
+    import sys
+    sys.exit(0)
+    print('Running aio code')
+    t = timeit.Timer('asyncio.run(aio())', globals=globals())
+    #print(t.autorange())
+    print(t.timeit(10))
+    #
+    print('Running non aio code')
+    t = timeit.Timer('noaio()', globals=globals())
+    #print(t.autorange())
+    print(t.timeit(10))
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..18ff0b06d7d91a2ddf2d98f9ada20e5f82a7afae 100644 (file)
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+# SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
+# SPDX-License-Identifier: LGPL-3.0-or-later
+
+HELLO_PREFIX = 'OK MPD '
+ERROR_PREFIX = 'ACK '
+SUCCESS = 'OK'
+NEXT = 'list_OK'
+#: Module version
+VERSION = '0.9.0b2'
+#: Seconds before a connection attempt times out
+#: (overriden by :envvar:`MPD_TIMEOUT` env. var.)
+CONNECTION_TIMEOUT = 30
+#: Socket timeout in second > 0 (Default is :py:obj:`None` for no timeout)
+SOCKET_TIMEOUT = None
+#: Maximum concurrent connections
+CONNECTION_MAX = 100
+
index 190f809f19ba1bb89675c92153ecca1167d8303f..c79009316ce3c9a7407f225f8cf0588659910c6a 100644 (file)
@@ -9,28 +9,15 @@ from .connection import ConnectionPool, Connection
 from .exceptions import MPDConnectionError, MPDProtocolError, MPDCommandError
 from .utils import Range, escape
 
-HELLO_PREFIX = 'OK MPD '
-ERROR_PREFIX = 'ACK '
-SUCCESS = 'OK'
-NEXT = 'list_OK'
-#: Module version
-VERSION = '0.9.0b2'
-#: Seconds before a connection attempt times out
-#: (overriden by :envvar:`MPD_TIMEOUT` env. var.)
-CONNECTION_TIMEOUT = 30
-#: Socket timeout in second > 0 (Default is :py:obj:`None` for no timeout)
-SOCKET_TIMEOUT = None
-#: Maximum concurrent connections
-CONNECTION_MAX = 10
-
-logging.basicConfig(level=logging.DEBUG,
-                    format='%(levelname)-8s %(module)-10s %(message)s')
+from . import CONNECTION_MAX, CONNECTION_TIMEOUT
+from . import ERROR_PREFIX, SUCCESS, NEXT
+
 log = logging.getLogger(__name__)
 
 
 class MPDClient:
 
-    def __init__(self,):
+    def __init__(self, host: str | None = None, port: str | int | None = None, password: str | None = None):
         self._commands = {
             # Status Commands
             "clearerror":         self._fetch_nothing,
@@ -162,28 +149,31 @@ class MPDClient:
             "readmessages":       self._fetch_messages,
             "sendmessage":        self._fetch_nothing,
         }
+        self._get_envvars()
         #: host used with the current connection (:py:obj:`str`)
-        self.host = None
+        self.host = host or self.server_discovery[0]
         #: password detected in :envvar:`MPD_HOST` environment variable (:py:obj:`str`)
-        self.pwd = None
+        self.password = password or self.server_discovery[2]
         #: port used with the current connection (:py:obj:`int`, :py:obj:`str`)
-        self.port = None
-        self._get_envvars()
+        self.port = port or self.server_discovery[1]
+        self._get_envvars()
         self._pool = ConnectionPool(max_connections=CONNECTION_MAX)
         log.info('logger : "%s"', __name__)
         #: current connection
-        self.connection: [None,Connection] = None
+        self.connection: [None, Connection] = None
         #: Protocol version
-        self.version: [None,str] = None
+        self.version: [None, str] = None
         self._command_list = None
+        self.mpd_timeout = CONNECTION_TIMEOUT
 
     def _get_envvars(self):
         """
         Retrieve MPD env. var. to overrides default "localhost:6600"
         """
         # Set some defaults
-        self.host = 'localhost'
-        self.port = os.getenv('MPD_PORT', '6600')
+        disco_host = 'localhost'
+        disco_port = os.getenv('MPD_PORT', '6600')
+        pwd = None
         _host = os.getenv('MPD_HOST', '')
         if _host:
             # If password is set: MPD_HOST=pass@host
@@ -193,33 +183,34 @@ class MPDClient:
                     # A password is actually set
                     log.debug(
                         'password detected in MPD_HOST, set client pwd attribute')
-                    self.pwd = mpd_host_env[0]
+                    pwd = mpd_host_env[0]
                     if mpd_host_env[1]:
-                        self.host = mpd_host_env[1]
-                        log.debug('host detected in MPD_HOST: %s', self.host)
+                        disco_host = mpd_host_env[1]
+                        log.debug('host detected in MPD_HOST: %s', disco_host)
                 elif mpd_host_env[1]:
                     # No password set but leading @ is an abstract socket
-                    self.host = '@'+mpd_host_env[1]
+                    disco_host = '@'+mpd_host_env[1]
                     log.debug(
-                        'host detected in MPD_HOST: %s (abstract socket)', self.host)
+                        'host detected in MPD_HOST: %s (abstract socket)', disco_host)
             else:
                 # MPD_HOST is a plain host
-                self.host = _host
-                log.debug('host detected in MPD_HOST: %s', self.host)
+                disco_host = _host
+                log.debug('host detected in MPD_HOST: %s', disco_host)
         else:
             # Is socket there
             xdg_runtime_dir = os.getenv('XDG_RUNTIME_DIR', '/run')
             rundir = os.path.join(xdg_runtime_dir, 'mpd/socket')
             if os.path.exists(rundir):
-                self.host = rundir
+                disco_host = rundir
                 log.debug(
-                    'host detected in ${XDG_RUNTIME_DIR}/run: %s (unix socket)', self.host)
+                    'host detected in ${XDG_RUNTIME_DIR}/run: %s (unix socket)', disco_host)
         _mpd_timeout = os.getenv('MPD_TIMEOUT', '')
         if _mpd_timeout.isdigit():
             self.mpd_timeout = int(_mpd_timeout)
             log.debug('timeout detected in MPD_TIMEOUT: %d', self.mpd_timeout)
         else:  # Use CONNECTION_TIMEOUT as default even if MPD_TIMEOUT carries gargage
             self.mpd_timeout = CONNECTION_TIMEOUT
+        self.server_discovery = (disco_host, disco_port, pwd)
 
     def __getattr__(self, attr):
         # if attr == 'send_noidle':  # have send_noidle to cancel idle as well as noidle
@@ -242,8 +233,10 @@ class MPDClient:
         return lambda *args: wrapper(command, args)
 
     async def _execute(self, command, args):  # pylint: disable=unused-argument
-        self.connection = await self._pool.connect(self.host, self.port)
-        async with self.connection:
+        log.debug(f'#{command}')
+        # self.connection = await self._pool.connect(self.host, self.port, timeout=self.mpd_timeout)
+        # await self._get_connection()
+        async with await self._get_connection():
             # if self._pending:
             #     raise MPDCommandError(
             #         f"Cannot execute '{command}' with pending commands")
@@ -257,7 +250,7 @@ class MPDClient:
             else:
                 await self._write_command(command, args)
                 if callable(retval):
-                    log.debug('retvat: %s', retval)
+                    log.debug('retvat: %s', retval)
                     return await retval()
                 return retval
             return None
@@ -331,10 +324,11 @@ class MPDClient:
 
     async def _read_list(self):
         seen = None
-        for key, value in await self._read_pairs():
+        async for key, value in self._read_pairs():
             if key != seen:
                 if seen is not None:
-                    raise MPDProtocolError(f"Expected key '{seen}', got '{key}'")
+                    raise MPDProtocolError(
+                        f"Expected key '{seen}', got '{key}'")
                 seen = key
             yield value
 
@@ -376,13 +370,13 @@ class MPDClient:
             raise ProtocolError(f"Got unexpected return value: '{line}'")
 
     async def _fetch_item(self):
-        pairs = list(await self._read_pairs())
+        pairs = [_ async for _ in self._read_pairs()]
         if len(pairs) != 1:
             return None
         return pairs[0][1]
 
-    def _fetch_list(self):
-        return self._read_list()
+    async def _fetch_list(self):
+        return [_ async for _ in self._read_list()]
 
     def _fetch_playlist(self):
         return self._read_playlist()
@@ -453,28 +447,17 @@ class MPDClient:
     def _fetch_command_list(self):
         return self._read_command_list()
 
-    async def _hello(self):
-        """Consume HELLO_PREFIX"""
-        data = await self.connection.readuntil(b'\n')
-        rcv = data.decode('utf-8')
-        if not rcv.startswith(HELLO_PREFIX):
-            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
-        log.debug('consumed hello prefix: %r', rcv)
-        self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
-        log.info('protocol version: %s', self.version)
-
-    async def connect(self, server=None, port=None) -> Connection:
-        if not server:
-            server = self.host
-        if not port:
-            port = self.port
-        try:
-            self.connection = await self._pool.connect(server, port)
-        except (ValueError, OSError) as err:
-            #TODO: Is this the right way to raise Excep
-            raise MPDConnectionError(err) from err
-        async with self.connection:
-            await self._hello()
+    async def _get_connection(self) -> Connection:
+        self.connection = await self._pool.connect(self.host, self.port, timeout=self.mpd_timeout)
+        return self.connection
 
     async def close(self):
-        await self.connection.close()
+        await self._pool.close()
+
+
+class CmdHandler:
+    #TODO: CmdHandler to intanciate in place of MPDClient._execute
+    # The MPDClient.__getattr__ wrapper should instanciate an CmdHandler object
+
+    def __init__(self):
+        pass
index d6f025b26fb63eedab8f989edb3a68362dc7d002..f327a92f109afbc062d6006a31861e064b3df2d2 100644 (file)
@@ -12,6 +12,8 @@ from collections import OrderedDict
 from types import TracebackType
 from typing import Any, List, Optional, Tuple, Type
 
+from . import HELLO_PREFIX
+from .exceptions import MPDProtocolError
 
 try:  # Python 3.7
     base = contextlib.AbstractAsyncContextManager
@@ -37,7 +39,7 @@ class ConnectionPool(base):
                                        List["Connection"]] = OrderedDict()
         self._semaphore = asyncio.Semaphore(max_connections)
 
-    async def connect(self, server: Server, port: Port) -> "Connection":
+    async def connect(self, server: Server, port: Port, timeout: int) -> "Connection":
         host = (server, port)
 
         # enforce the connection limit, releasing connections notifies
@@ -58,13 +60,20 @@ class ConnectionPool(base):
                         await conn.close()
                         break
 
-            reader, writer = await asyncio.open_connection(server, port)
+            log.debug('about to connect %s', host)
+            reader, writer = await asyncio.wait_for(
+                    asyncio.open_connection(server, port),
+                    timeout
+                    )
+            log.info('Connected to %s:%s', host[0], host[1])
             connection = Connection(self, host, reader, writer)
+            await connection._hello()
             connections.append(connection)
 
         connection.in_use = True
         # move current host to the front as most-recently used
         self._connections.move_to_end(host, False)
+        log.debug('connection %s in use', connection)
 
         return connection
 
@@ -126,7 +135,7 @@ class Connection(base):
 
     def __repr__(self):
         host = f"{self._host[0]}:{self._host[1]}"
-        return f"Connection<{host}>"
+        return f"Connection<{host}>#{id(self)}"
 
     @property
     def closed(self):
@@ -149,6 +158,18 @@ class Connection(base):
         except AttributeError:  # wait_closed is new in 3.7
             pass
 
+    async def _hello(self) -> None:
+        """Consume HELLO_PREFIX"""
+        self.in_use = True
+        data = await self.readuntil(b'\n')
+        rcv = data.decode('utf-8')
+        if not rcv.startswith(HELLO_PREFIX):
+            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
+        log.debug('consumed hello prefix: %r', rcv)
+        self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
+        log.info('protocol version: %s', self.version)
+
+
     def __getattr__(self, name: str) -> Any:
         """All unknown attributes are delegated to the reader and writer"""
         if self._closed or not self.in_use: