+# -*- coding: utf-8 -*-
+# SPDX-FileCopyrightText: 2012-2024 kaliko <kaliko@azylum.org>
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket#%E2%80%A6
+"""
+
+import asyncio
+import contextlib
+import logging
+
+from collections import OrderedDict
+from types import TracebackType
+from typing import Any, List, Optional, Tuple, Type
+
+
+try: # Python 3.7
+ base = contextlib.AbstractAsyncContextManager
+except AttributeError as err:
+ base = object # type: ignore
+
+Server = str
+Port = int
+Host = Tuple[Server, Port]
+log = logging.getLogger(__name__)
+
+
+class ConnectionPool(base):
+ def __init__(
+ self,
+ max_connections: int = 1000,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ ):
+ self.max_connections = max_connections
+ self._loop = loop or asyncio.get_event_loop()
+
+ self._connections: OrderedDict[Host,
+ List["Connection"]] = OrderedDict()
+ self._semaphore = asyncio.Semaphore(max_connections)
+
+ async def connect(self, server: Server, port: Port) -> "Connection":
+ host = (server, port)
+
+ # enforce the connection limit, releasing connections notifies
+ # the semaphore to release here
+ await self._semaphore.acquire()
+
+ connections = self._connections.setdefault(host, [])
+ log.info('got %s in pool', len(connections))
+ # find an un-used connection for this host
+ connection = next(
+ (conn for conn in connections if not conn.in_use), None)
+ if connection is None:
+ # disconnect the least-recently-used un-used connection to make space
+ # for a new connection. There will be at least one.
+ for conns_per_host in reversed(self._connections.values()):
+ for conn in conns_per_host:
+ if not conn.in_use:
+ await conn.close()
+ break
+
+ reader, writer = await asyncio.open_connection(server, port)
+ connection = Connection(self, host, reader, writer)
+ connections.append(connection)
+
+ connection.in_use = True
+ # move current host to the front as most-recently used
+ self._connections.move_to_end(host, False)
+
+ return connection
+
+ async def close(self):
+ """Close all connections"""
+ connections = [c for cs in self._connections.values() for c in cs]
+ self._connections = OrderedDict()
+ for connection in connections:
+ await connection.close()
+
+ def _remove(self, connection):
+ conns_for_host = self._connections.get(connection._host)
+ if not conns_for_host:
+ return
+ conns_for_host[:] = [c for c in conns_for_host if c != connection]
+
+ def _notify_release(self):
+ self._semaphore.release()
+
+ async def __aenter__(self) -> "ConnectionPool":
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> None:
+ await self.close()
+
+ def __del__(self) -> None:
+ connections = [repr(c)
+ for cs in self._connections.values() for c in cs]
+ if not connections:
+ return
+
+ context = {
+ "pool": self,
+ "connections": connections,
+ "message": "Unclosed connection pool",
+ }
+ self._loop.call_exception_handler(context)
+
+
+class Connection(base):
+ def __init__(
+ self,
+ pool: ConnectionPool,
+ host: Host,
+ reader: asyncio.StreamReader,
+ writer: asyncio.StreamWriter,
+ ):
+ self._host = host
+ self._pool = pool
+ self._reader = reader
+ self._writer = writer
+ self._closed = False
+ self.in_use = False
+
+ def __repr__(self):
+ host = f"{self._host[0]}:{self._host[1]}"
+ return f"Connection<{host}>"
+
+ @property
+ def closed(self):
+ return self._closed
+
+ def release(self) -> None:
+ self.in_use = False
+ self._pool._notify_release()
+
+ async def close(self) -> None:
+ if self._closed:
+ return
+ self._closed = True
+ self._writer.close()
+ self._pool._remove(self)
+ try:
+ await self._writer.wait_closed()
+ except AttributeError: # wait_closed is new in 3.7
+ pass
+
+ def __getattr__(self, name: str) -> Any:
+ """All unknown attributes are delegated to the reader and writer"""
+ if self._closed or not self.in_use:
+ raise ValueError("Can't use a closed or unacquired connection")
+ if hasattr(self._reader, name):
+ return getattr(self._reader, name)
+ return getattr(self._writer, name)
+
+ async def __aenter__(self) -> "Connection":
+ if self._closed or not self.in_use:
+ raise ValueError("Can't use a closed or unacquired connection")
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> None:
+ self.release()
+
+ def __del__(self) -> None:
+ if self._closed:
+ return
+ context = {"connection": self, "message": "Unclosed connection"}
+ self._pool._loop.call_exception_handler(context)