1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024 kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
4 """https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket#%E2%80%A6
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
15 from . import HELLO_PREFIX
16 from .exceptions import MPDProtocolError
19 base = contextlib.AbstractAsyncContextManager
20 except AttributeError as err:
21 base = object # type: ignore
25 Host = Tuple[Server, Port]
26 log = logging.getLogger(__name__)
29 class ConnectionPool(base):
32 max_connections: int = 1000,
33 loop: Optional[asyncio.AbstractEventLoop] = None,
35 self.max_connections = max_connections
36 self._loop = loop or asyncio.get_event_loop()
38 self._connections: OrderedDict[Host,
39 List["Connection"]] = OrderedDict()
40 self._semaphore = asyncio.Semaphore(max_connections)
42 async def connect(self, server: Server, port: Port, timeout: int) -> "Connection":
45 # enforce the connection limit, releasing connections notifies
46 # the semaphore to release here
47 await self._semaphore.acquire()
49 connections = self._connections.setdefault(host, [])
50 log.debug('got %s in pool', len(connections))
51 # find an un-used connection for this host
53 (conn for conn in connections if not conn.in_use), None)
54 if connection is None:
55 # disconnect the least-recently-used un-used connection to make space
56 # for a new connection. There will be at least one.
57 for conns_per_host in reversed(self._connections.values()):
58 for conn in conns_per_host:
63 log.debug('about to connect %s', host)
64 reader, writer = await asyncio.wait_for(
65 asyncio.open_connection(server, port),
68 log.info('Connected to %s:%s', host[0], host[1])
69 connection = Connection(self, host, reader, writer)
70 await connection._hello()
71 connections.append(connection)
73 connection.in_use = True
74 # move current host to the front as most-recently used
75 self._connections.move_to_end(host, False)
76 log.debug('connection %s in use', connection)
async def close(self):
    """Close every connection currently tracked by the pool.

    The pool's host mapping is replaced by an empty one before any
    connection is closed, then each previously tracked connection is
    closed one after the other, in host order.
    """
    # Take a flat snapshot first: the mapping is swapped out below, and
    # the connections are awaited only after the pool no longer lists them.
    pending = []
    for per_host in self._connections.values():
        pending.extend(per_host)
    self._connections = OrderedDict()
    for conn in pending:
        await conn.close()
87 def _remove(self, connection):
88 conns_for_host = self._connections.get(connection._host)
89 if not conns_for_host:
91 conns_for_host[:] = [c for c in conns_for_host if c != connection]
def _notify_release(self):
    """Give one slot back to the pool's connection-limit semaphore."""
    limiter = self._semaphore
    limiter.release()
96 async def __aenter__(self) -> "ConnectionPool":
101 exc_type: Optional[Type[BaseException]],
102 exc: Optional[BaseException],
103 tb: Optional[TracebackType],
107 def __del__(self) -> None:
108 connections = [repr(c)
109 for cs in self._connections.values() for c in cs]
115 "connections": connections,
116 "message": "Unclosed connection pool",
118 self._loop.call_exception_handler(context)
121 class Connection(base):
124 pool: ConnectionPool,
126 reader: asyncio.StreamReader,
127 writer: asyncio.StreamWriter,
131 self._reader = reader
132 self._writer = writer
137 host = f"{self._host[0]}:{self._host[1]}"
138 return f"Connection<{host}>#{id(self)}"
144 def release(self) -> None:
145 logging.debug('releasing %r', self)
147 self._pool._notify_release()
149 async def close(self) -> None:
152 logging.debug('closing %r', self)
155 self._pool._remove(self)
157 await self._writer.wait_closed()
158 except AttributeError: # wait_closed is new in 3.7
161 async def _hello(self) -> None:
162 """Consume HELLO_PREFIX"""
164 data = await self.readuntil(b'\n')
165 rcv = data.decode('utf-8')
166 if not rcv.startswith(HELLO_PREFIX):
167 raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
168 log.debug('consumed hello prefix: %r', rcv)
169 self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
170 log.info('protocol version: %s', self.version)
173 def __getattr__(self, name: str) -> Any:
174 """All unknown attributes are delegated to the reader and writer"""
175 if self._closed or not self.in_use:
176 raise ValueError("Can't use a closed or unacquired connection")
177 if hasattr(self._reader, name):
178 return getattr(self._reader, name)
179 return getattr(self._writer, name)
181 async def __aenter__(self) -> "Connection":
182 if self._closed or not self.in_use:
183 raise ValueError("Can't use a closed or unacquired connection")
188 exc_type: Optional[Type[BaseException]],
189 exc: Optional[BaseException],
190 tb: Optional[TracebackType],
194 def __del__(self) -> None:
197 context = {"connection": self, "message": "Unclosed connection"}
198 self._pool._loop.call_exception_handler(context)