1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024 kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
4 """https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket#%E2%80%A6
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
17 base = contextlib.AbstractAsyncContextManager
18 except AttributeError as err:
19 base = object # type: ignore
23 Host = Tuple[Server, Port]
24 log = logging.getLogger(__name__)
27 class ConnectionPool(base):
30 max_connections: int = 1000,
31 loop: Optional[asyncio.AbstractEventLoop] = None,
33 self.max_connections = max_connections
34 self._loop = loop or asyncio.get_event_loop()
36 self._connections: OrderedDict[Host,
37 List["Connection"]] = OrderedDict()
38 self._semaphore = asyncio.Semaphore(max_connections)
40 async def connect(self, server: Server, port: Port) -> "Connection":
43 # enforce the connection limit, releasing connections notifies
44 # the semaphore to release here
45 await self._semaphore.acquire()
47 connections = self._connections.setdefault(host, [])
48 log.debug('got %s in pool', len(connections))
49 # find an un-used connection for this host
51 (conn for conn in connections if not conn.in_use), None)
52 if connection is None:
53 # disconnect the least-recently-used un-used connection to make space
54 # for a new connection. There will be at least one.
55 for conns_per_host in reversed(self._connections.values()):
56 for conn in conns_per_host:
61 reader, writer = await asyncio.open_connection(server, port)
62 connection = Connection(self, host, reader, writer)
63 connections.append(connection)
65 connection.in_use = True
66 # move current host to the front as most-recently used
67 self._connections.move_to_end(host, False)
async def close(self) -> None:
    """Close all connections held by the pool.

    The pool's host -> connections mapping is replaced with an empty one
    before any connection is closed. Every connection is then given a chance
    to close, even if an earlier one fails. Without this, a single failing
    close would leave the remaining connections open and no longer tracked
    by the pool. The first failure is re-raised once all connections have
    been attempted.
    """
    connections = [c for cs in self._connections.values() for c in cs]
    self._connections = OrderedDict()
    first_error: Optional[Exception] = None
    for connection in connections:
        try:
            await connection.close()
        except Exception as err:  # keep closing the rest; report afterwards
            if first_error is None:
                first_error = err
    if first_error is not None:
        raise first_error
78 def _remove(self, connection):
79 conns_for_host = self._connections.get(connection._host)
80 if not conns_for_host:
82 conns_for_host[:] = [c for c in conns_for_host if c != connection]
def _notify_release(self) -> None:
    """Hand one slot back to the pool's connection-limit semaphore.

    ``connect`` acquires this semaphore before handing out a connection, so
    releasing it lets a waiting ``connect`` call proceed.
    """
    limiter = self._semaphore
    limiter.release()
87 async def __aenter__(self) -> "ConnectionPool":
92 exc_type: Optional[Type[BaseException]],
93 exc: Optional[BaseException],
94 tb: Optional[TracebackType],
98 def __del__(self) -> None:
99 connections = [repr(c)
100 for cs in self._connections.values() for c in cs]
106 "connections": connections,
107 "message": "Unclosed connection pool",
109 self._loop.call_exception_handler(context)
112 class Connection(base):
115 pool: ConnectionPool,
117 reader: asyncio.StreamReader,
118 writer: asyncio.StreamWriter,
122 self._reader = reader
123 self._writer = writer
128 host = f"{self._host[0]}:{self._host[1]}"
129 return f"Connection<{host}>"
135 def release(self) -> None:
136 logging.debug('releasing %r', self)
138 self._pool._notify_release()
140 async def close(self) -> None:
143 logging.debug('closing %r', self)
146 self._pool._remove(self)
148 await self._writer.wait_closed()
149 except AttributeError: # wait_closed is new in 3.7
def __getattr__(self, name: str) -> Any:
    """Delegate unknown attributes to the wrapped reader, then the writer.

    Python only calls this for attributes that normal lookup did not find.
    A name present on the reader wins over the same name on the writer.

    :raises AttributeError: when one of the connection's own bookkeeping
        attributes (``_closed``, ``in_use``, ``_reader``, ``_writer``) has not
        been set, instead of recursing into this method forever.
    :raises ValueError: if the connection is closed or not currently in use.
    """
    # The checks below read these attributes from self. If one is missing
    # (object not fully initialised, or created via __new__ by copy/pickle),
    # reading it would re-enter __getattr__ and recurse until RecursionError.
    # When they are set, normal lookup finds them and this branch never runs.
    if name in ("_closed", "in_use", "_reader", "_writer"):
        raise AttributeError(name)
    if self._closed or not self.in_use:
        raise ValueError("Can't use a closed or unacquired connection")
    if hasattr(self._reader, name):
        return getattr(self._reader, name)
    return getattr(self._writer, name)
160 async def __aenter__(self) -> "Connection":
161 if self._closed or not self.in_use:
162 raise ValueError("Can't use a closed or unacquired connection")
167 exc_type: Optional[Type[BaseException]],
168 exc: Optional[BaseException],
169 tb: Optional[TracebackType],
173 def __del__(self) -> None:
176 context = {"connection": self, "message": "Unclosed connection"}
177 self._pool._loop.call_exception_handler(context)