X-Git-Url: http://git.kaliko.me/?a=blobdiff_plain;f=mpdaio%2Fconnection.py;fp=mpdaio%2Fconnection.py;h=32ad65f8363b5de84359baf17f40cada19764366;hb=c4cd26737adc648cb948ab76565349ba25d3ce4e;hp=0000000000000000000000000000000000000000;hpb=3bb431f9ff65e5ae2c4a554189a6ae23e6c00108;p=python-musicpdaio.git diff --git a/mpdaio/connection.py b/mpdaio/connection.py new file mode 100644 index 0000000..32ad65f --- /dev/null +++ b/mpdaio/connection.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: 2012-2024 kaliko +# SPDX-License-Identifier: LGPL-3.0-or-later +"""https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket#%E2%80%A6 +""" + +import asyncio +import contextlib +import logging + +from collections import OrderedDict +from types import TracebackType +from typing import Any, List, Optional, Tuple, Type + + +try: # Python 3.7 + base = contextlib.AbstractAsyncContextManager +except AttributeError as err: + base = object # type: ignore + +Server = str +Port = int +Host = Tuple[Server, Port] +log = logging.getLogger(__name__) + + +class ConnectionPool(base): + def __init__( + self, + max_connections: int = 1000, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + self.max_connections = max_connections + self._loop = loop or asyncio.get_event_loop() + + self._connections: OrderedDict[Host, + List["Connection"]] = OrderedDict() + self._semaphore = asyncio.Semaphore(max_connections) + + async def connect(self, server: Server, port: Port) -> "Connection": + host = (server, port) + + # enforce the connection limit, releasing connections notifies + # the semaphore to release here + await self._semaphore.acquire() + + connections = self._connections.setdefault(host, []) + log.info('got %s in pool', len(connections)) + # find an un-used connection for this host + connection = next( + (conn for conn in connections if not conn.in_use), None) + if connection is None: + # disconnect the least-recently-used un-used connection to make space + # for a new connection. There will be at least one. + for conns_per_host in reversed(self._connections.values()): + for conn in conns_per_host: + if not conn.in_use: + await conn.close() + break + + reader, writer = await asyncio.open_connection(server, port) + connection = Connection(self, host, reader, writer) + connections.append(connection) + + connection.in_use = True + # move current host to the front as most-recently used + self._connections.move_to_end(host, False) + + return connection + + async def close(self): + """Close all connections""" + connections = [c for cs in self._connections.values() for c in cs] + self._connections = OrderedDict() + for connection in connections: + await connection.close() + + def _remove(self, connection): + conns_for_host = self._connections.get(connection._host) + if not conns_for_host: + return + conns_for_host[:] = [c for c in conns_for_host if c != connection] + + def _notify_release(self): + self._semaphore.release() + + async def __aenter__(self) -> "ConnectionPool": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + await self.close() + + def __del__(self) -> None: + connections = [repr(c) + for cs in self._connections.values() for c in cs] + if not connections: + return + + context = { + "pool": self, + "connections": connections, + "message": "Unclosed connection pool", + } + self._loop.call_exception_handler(context) + + +class Connection(base): + def __init__( + self, + pool: ConnectionPool, + host: Host, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + self._host = host + self._pool = pool + self._reader = reader + self._writer = writer + self._closed = False + self.in_use = False + + def __repr__(self): + host = f"{self._host[0]}:{self._host[1]}" + return f"Connection<{host}>" + + @property + def closed(self): + return self._closed + + def release(self) -> None: + self.in_use = False + self._pool._notify_release() + + async def close(self) -> None: + if self._closed: + return + self._closed = True + self._writer.close() + self._pool._remove(self) + try: + await self._writer.wait_closed() + except AttributeError: # wait_closed is new in 3.7 + pass + + def __getattr__(self, name: str) -> Any: + """All unknown attributes are delegated to the reader and writer""" + if self._closed or not self.in_use: + raise ValueError("Can't use a closed or unacquired connection") + if hasattr(self._reader, name): + return getattr(self._reader, name) + return getattr(self._writer, name) + + async def __aenter__(self) -> "Connection": + if self._closed or not self.in_use: + raise ValueError("Can't use a closed or unacquired connection") + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + self.release() + + def __del__(self) -> None: + if self._closed: + return + context = {"connection": self, "message": "Unclosed connection"} + self._pool._loop.call_exception_handler(context)