]> kaliko git repositories - python-musicpdaio.git/blobdiff - mpdaio/connection.py
Add module mpdaio
[python-musicpdaio.git] / mpdaio / connection.py
diff --git a/mpdaio/connection.py b/mpdaio/connection.py
new file mode 100644 (file)
index 0000000..32ad65f
--- /dev/null
@@ -0,0 +1,175 @@
+# -*- coding: utf-8 -*-
+# SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket#%E2%80%A6
+"""
+
+import asyncio
+import contextlib
+import logging
+
+from collections import OrderedDict
+from types import TracebackType
+from typing import Any, List, Optional, Tuple, Type
+
+
+try:  # Python 3.7
+    base = contextlib.AbstractAsyncContextManager
+except AttributeError as err:
+    base = object  # type: ignore
+
+Server = str
+Port = int
+Host = Tuple[Server, Port]
+log = logging.getLogger(__name__)
+
+
+class ConnectionPool(base):
+    def __init__(
+        self,
+        max_connections: int = 1000,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ):
+        self.max_connections = max_connections
+        self._loop = loop or asyncio.get_event_loop()
+
+        self._connections: OrderedDict[Host,
+                                       List["Connection"]] = OrderedDict()
+        self._semaphore = asyncio.Semaphore(max_connections)
+
+    async def connect(self, server: Server, port: Port) -> "Connection":
+        host = (server, port)
+
+        # enforce the connection limit, releasing connections notifies
+        # the semaphore to release here
+        await self._semaphore.acquire()
+
+        connections = self._connections.setdefault(host, [])
+        log.info('got %s in pool', len(connections))
+        # find an un-used connection for this host
+        connection = next(
+            (conn for conn in connections if not conn.in_use), None)
+        if connection is None:
+            # disconnect the least-recently-used un-used connection to make space
+            # for a new connection. There will be at least one.
+            for conns_per_host in reversed(self._connections.values()):
+                for conn in conns_per_host:
+                    if not conn.in_use:
+                        await conn.close()
+                        break
+
+            reader, writer = await asyncio.open_connection(server, port)
+            connection = Connection(self, host, reader, writer)
+            connections.append(connection)
+
+        connection.in_use = True
+        # move current host to the front as most-recently used
+        self._connections.move_to_end(host, False)
+
+        return connection
+
+    async def close(self):
+        """Close all connections"""
+        connections = [c for cs in self._connections.values() for c in cs]
+        self._connections = OrderedDict()
+        for connection in connections:
+            await connection.close()
+
+    def _remove(self, connection):
+        conns_for_host = self._connections.get(connection._host)
+        if not conns_for_host:
+            return
+        conns_for_host[:] = [c for c in conns_for_host if c != connection]
+
+    def _notify_release(self):
+        self._semaphore.release()
+
+    async def __aenter__(self) -> "ConnectionPool":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
+        await self.close()
+
+    def __del__(self) -> None:
+        connections = [repr(c)
+                       for cs in self._connections.values() for c in cs]
+        if not connections:
+            return
+
+        context = {
+            "pool": self,
+            "connections": connections,
+            "message": "Unclosed connection pool",
+        }
+        self._loop.call_exception_handler(context)
+
+
+class Connection(base):
+    def __init__(
+        self,
+        pool: ConnectionPool,
+        host: Host,
+        reader: asyncio.StreamReader,
+        writer: asyncio.StreamWriter,
+    ):
+        self._host = host
+        self._pool = pool
+        self._reader = reader
+        self._writer = writer
+        self._closed = False
+        self.in_use = False
+
+    def __repr__(self):
+        host = f"{self._host[0]}:{self._host[1]}"
+        return f"Connection<{host}>"
+
+    @property
+    def closed(self):
+        return self._closed
+
+    def release(self) -> None:
+        self.in_use = False
+        self._pool._notify_release()
+
+    async def close(self) -> None:
+        if self._closed:
+            return
+        self._closed = True
+        self._writer.close()
+        self._pool._remove(self)
+        try:
+            await self._writer.wait_closed()
+        except AttributeError:  # wait_closed is new in 3.7
+            pass
+
+    def __getattr__(self, name: str) -> Any:
+        """All unknown attributes are delegated to the reader and writer"""
+        if self._closed or not self.in_use:
+            raise ValueError("Can't use a closed or unacquired connection")
+        if hasattr(self._reader, name):
+            return getattr(self._reader, name)
+        return getattr(self._writer, name)
+
+    async def __aenter__(self) -> "Connection":
+        if self._closed or not self.in_use:
+            raise ValueError("Can't use a closed or unacquired connection")
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
+        self.release()
+
+    def __del__(self) -> None:
+        if self._closed:
+            return
+        context = {"connection": self, "message": "Unclosed connection"}
+        self._pool._loop.call_exception_handler(context)