1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024 kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
"""https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
16 base = contextlib.AbstractAsyncContextManager
17 except AttributeError as err:
18 base = object # type: ignore
22 Host = Tuple[Server, Port]
23 log = logging.getLogger(__name__)
26 class ConnectionPool(base):
29 max_connections: int = 1000,
30 loop: Optional[asyncio.AbstractEventLoop] = None,
32 self.max_connections = max_connections
33 self._loop = loop or asyncio.get_event_loop()
35 self._connections: OrderedDict[Host,
36 List["Connection"]] = OrderedDict()
37 self._semaphore = asyncio.Semaphore(max_connections)
39 async def connect(self, server: Server, port: Port, timeout: int) -> "Connection":
42 # enforce the connection limit, releasing connections notifies
43 # the semaphore to release here
44 await self._semaphore.acquire()
46 connections = self._connections.setdefault(host, [])
47 log.debug('got %s in pool', len(connections))
48 # find an un-used connection for this host
50 (conn for conn in connections if not conn.in_use), None)
52 # log.debug('reusing %s', connection)
53 if connection is None:
54 # disconnect the least-recently-used un-used connection to make space
55 # for a new connection. There will be at least one.
56 for conns_per_host in reversed(self._connections.values()):
57 for conn in conns_per_host:
61 if server[0] in ['/', '@']:
62 log.debug('about to connect unix socket %s', server)
63 reader, writer = await asyncio.wait_for(
64 asyncio.open_unix_connection(path=server),
68 log.debug('about to connect tcp socket %s:%s', *host)
69 reader, writer = await asyncio.wait_for(
70 asyncio.open_connection(server, port),
73 #log.debug('Connected to %s:%s', host[0], host[1])
74 connection = Connection(self, host, reader, writer)
75 connections.append(connection)
77 connection.in_use = True
78 # move current host to the front as most-recently used
79 self._connections.move_to_end(host, False)
80 log.debug('connection %s in use', connection)
async def close(self):
    """Close all connections

    Every connection currently registered in the pool is closed and the
    pool's registry is replaced by an empty one.

    A failure while closing one connection does not prevent the remaining
    ones from being closed: all of them are attempted, then the first
    error encountered is re-raised.

    :raises Exception: the first exception raised by a connection's
        ``close()``, once every connection has been attempted.
    """
    pending = [conn for conns in self._connections.values() for conn in conns]
    # Swap in an empty registry before closing anything; from here on the
    # connections are only reachable through ``pending``, so each of them
    # must be attempted even if an earlier one fails.
    self._connections = OrderedDict()
    log.info('Closing all connections')
    first_error = None
    for conn in pending:
        try:
            await conn.close()
        except Exception as err:
            # On Python 3.8+ asyncio.CancelledError derives from
            # BaseException, so cancellation is not caught here.
            if first_error is None:
                first_error = err
    if first_error is not None:
        raise first_error
92 def _remove(self, connection):
93 conns_for_host = self._connections.get(connection._host)
94 if not conns_for_host:
96 conns_for_host[:] = [c for c in conns_for_host if c != connection]
98 def _notify_release(self):
99 self._semaphore.release()
101 async def __aenter__(self) -> "ConnectionPool":
106 exc_type: Optional[Type[BaseException]],
107 exc: Optional[BaseException],
108 tb: Optional[TracebackType],
112 def __del__(self) -> None:
113 connections = [repr(c)
114 for cs in self._connections.values() for c in cs]
120 "connections": connections,
121 "message": "Unclosed connection pool",
123 self._loop.call_exception_handler(context)
126 class Connection(base):
129 pool: ConnectionPool,
131 reader: asyncio.StreamReader,
132 writer: asyncio.StreamWriter,
136 self._reader = reader
137 self._writer = writer
139 #: password command with the secret was sent
140 self.auth: bool = False
142 self.version: str | None = None
145 host = f"{self._host[0]}:{self._host[1]}"
146 return f"Connection<{host}>"
152 def release(self) -> None:
153 logging.debug('releasing %r', self)
155 self._pool._notify_release()
157 async def close(self) -> None:
160 logging.debug('closing %r', self)
163 self._pool._remove(self)
165 await self._writer.wait_closed()
166 except AttributeError: # wait_closed is new in 3.7
169 def __getattr__(self, name: str) -> Any:
170 """All unknown attributes are delegated to the reader and writer"""
171 if self._closed or not self.in_use:
172 raise ValueError("Can't use a closed or unacquired connection")
173 if hasattr(self._reader, name):
174 return getattr(self._reader, name)
175 return getattr(self._writer, name)
177 async def __aenter__(self) -> "Connection":
178 if self._closed or not self.in_use:
179 raise ValueError("Can't use a closed or unacquired connection")
184 exc_type: Optional[Type[BaseException]],
185 exc: Optional[BaseException],
186 tb: Optional[TracebackType],
190 def __del__(self) -> None:
193 context = {"connection": self, "message": "Unclosed connection"}
194 self._pool._loop.call_exception_handler(context)