1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024 kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
4 """https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket#%E2%80%A6
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
15 from . import HELLO_PREFIX
16 from .exceptions import MPDProtocolError
19 base = contextlib.AbstractAsyncContextManager
20 except AttributeError as err:
21 base = object # type: ignore
25 Host = Tuple[Server, Port]
26 log = logging.getLogger(__name__)
29 class ConnectionPool(base):
32 max_connections: int = 1000,
33 loop: Optional[asyncio.AbstractEventLoop] = None,
35 self.max_connections = max_connections
36 self._loop = loop or asyncio.get_event_loop()
38 self._connections: OrderedDict[Host,
39 List["Connection"]] = OrderedDict()
40 self._semaphore = asyncio.Semaphore(max_connections)
42 async def connect(self, server: Server, port: Port, timeout: int) -> "Connection":
45 # enforce the connection limit, releasing connections notifies
46 # the semaphore to release here
47 await self._semaphore.acquire()
49 connections = self._connections.setdefault(host, [])
50 log.debug('got %s in pool', len(connections))
51 # find an un-used connection for this host
53 (conn for conn in connections if not conn.in_use), None)
55 # log.debug('reusing %s', connection)
56 if connection is None:
57 # disconnect the least-recently-used un-used connection to make space
58 # for a new connection. There will be at least one.
59 for conns_per_host in reversed(self._connections.values()):
60 for conn in conns_per_host:
65 log.debug('about to connect %s', host)
66 reader, writer = await asyncio.wait_for(
67 asyncio.open_connection(server, port),
70 #log.debug('Connected to %s:%s', host[0], host[1])
71 connection = Connection(self, host, reader, writer)
72 await connection._hello()
73 connections.append(connection)
75 connection.in_use = True
76 # move current host to the front as most-recently used
77 self._connections.move_to_end(host, False)
78 log.debug('connection %s in use', connection)
async def close(self):
    """Close every connection currently held by the pool.

    The pooled connections are first collected into a flat list and the
    pool's host mapping is replaced by an empty one; each collected
    connection is then closed in turn.
    """
    pending = []
    for per_host in self._connections.values():
        pending.extend(per_host)
    # Reset the mapping before awaiting any close()
    self._connections = OrderedDict()
    log.info('Closing all connections')
    for conn in pending:
        await conn.close()
90 def _remove(self, connection):
91 conns_for_host = self._connections.get(connection._host)
92 if not conns_for_host:
94 conns_for_host[:] = [c for c in conns_for_host if c != connection]
96 def _notify_release(self):
97 self._semaphore.release()
99 async def __aenter__(self) -> "ConnectionPool":
104 exc_type: Optional[Type[BaseException]],
105 exc: Optional[BaseException],
106 tb: Optional[TracebackType],
110 def __del__(self) -> None:
111 connections = [repr(c)
112 for cs in self._connections.values() for c in cs]
118 "connections": connections,
119 "message": "Unclosed connection pool",
121 self._loop.call_exception_handler(context)
124 class Connection(base):
127 pool: ConnectionPool,
129 reader: asyncio.StreamReader,
130 writer: asyncio.StreamWriter,
134 self._reader = reader
135 self._writer = writer
140 host = f"{self._host[0]}:{self._host[1]}"
141 return f"Connection<{host}>#{id(self)}"
147 def release(self) -> None:
148 logging.debug('releasing %r', self)
150 self._pool._notify_release()
152 async def close(self) -> None:
155 logging.debug('closing %r', self)
158 self._pool._remove(self)
160 await self._writer.wait_closed()
161 except AttributeError: # wait_closed is new in 3.7
164 async def _hello(self) -> None:
165 """Consume HELLO_PREFIX"""
167 data = await self.readuntil(b'\n')
168 rcv = data.decode('utf-8')
169 if not rcv.startswith(HELLO_PREFIX):
170 raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
171 log.debug('consumed hello prefix: %r', rcv)
172 self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
173 log.info('protocol version: %s', self.version)
176 def __getattr__(self, name: str) -> Any:
177 """All unknown attributes are delegated to the reader and writer"""
178 if self._closed or not self.in_use:
179 raise ValueError("Can't use a closed or unacquired connection")
180 if hasattr(self._reader, name):
181 return getattr(self._reader, name)
182 return getattr(self._writer, name)
184 async def __aenter__(self) -> "Connection":
185 if self._closed or not self.in_use:
186 raise ValueError("Can't use a closed or unacquired connection")
191 exc_type: Optional[Type[BaseException]],
192 exc: Optional[BaseException],
193 tb: Optional[TracebackType],
197 def __del__(self) -> None:
200 context = {"connection": self, "message": "Unclosed connection"}
201 self._pool._loop.call_exception_handler(context)