1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024 kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
4 """https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket#%E2%80%A6
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
15 from . import HELLO_PREFIX
16 from .exceptions import MPDProtocolError
19 base = contextlib.AbstractAsyncContextManager
20 except AttributeError as err:
21 base = object # type: ignore
25 Host = Tuple[Server, Port]
26 log = logging.getLogger(__name__)
29 class ConnectionPool(base):
32 max_connections: int = 1000,
33 loop: Optional[asyncio.AbstractEventLoop] = None,
35 self.max_connections = max_connections
36 self._loop = loop or asyncio.get_event_loop()
38 self._connections: OrderedDict[Host,
39 List["Connection"]] = OrderedDict()
40 self._semaphore = asyncio.Semaphore(max_connections)
42 async def connect(self, server: Server, port: Port, timeout: int) -> "Connection":
45 # enforce the connection limit, releasing connections notifies
46 # the semaphore to release here
47 await self._semaphore.acquire()
49 connections = self._connections.setdefault(host, [])
50 log.debug('got %s in pool', len(connections))
51 # find an un-used connection for this host
53 (conn for conn in connections if not conn.in_use), None)
55 # log.debug('reusing %s', connection)
56 if connection is None:
57 # disconnect the least-recently-used un-used connection to make space
58 # for a new connection. There will be at least one.
59 for conns_per_host in reversed(self._connections.values()):
60 for conn in conns_per_host:
64 if server[0] in ['/', '@']:
65 log.debug('about to connect unix socket %s', server)
66 reader, writer = await asyncio.wait_for(
67 asyncio.open_unix_connection(path=server),
71 log.debug('about to connect tcp socket %s:%s', *host)
72 reader, writer = await asyncio.wait_for(
73 asyncio.open_connection(server, port),
76 #log.debug('Connected to %s:%s', host[0], host[1])
77 connection = Connection(self, host, reader, writer)
78 await connection._hello()
79 connections.append(connection)
81 connection.in_use = True
82 # move current host to the front as most-recently used
83 self._connections.move_to_end(host, False)
84 log.debug('connection %s in use', connection)
async def close(self) -> None:
    """Close all connections held by the pool.

    The pool's connection table is emptied first, so afterwards the pool no
    longer references any of these connections.  For that reason every
    connection is given a chance to close even if closing an earlier one
    fails; each failure is logged and the first one is re-raised once all
    connections have been attempted, so none is silently leaked.
    """
    connections = [c for cs in self._connections.values() for c in cs]
    self._connections = OrderedDict()
    log.info('Closing all connections')
    first_error = None
    for connection in connections:
        try:
            await connection.close()
        except Exception as err:
            # Keep going: the pool has already forgotten these connections,
            # so skipping the rest would leave their streams open.
            log.warning('Error while closing %r: %s', connection, err)
            if first_error is None:
                first_error = err
    if first_error is not None:
        raise first_error
96 def _remove(self, connection):
97 conns_for_host = self._connections.get(connection._host)
98 if not conns_for_host:
100 conns_for_host[:] = [c for c in conns_for_host if c != connection]
def _notify_release(self):
    """Give one slot back to the pool's connection limit.

    Invoked when a connection is released; releasing the semaphore lets a
    task that is waiting on the limit inside ``connect()`` proceed.
    """
    limit = self._semaphore
    limit.release()
105 async def __aenter__(self) -> "ConnectionPool":
110 exc_type: Optional[Type[BaseException]],
111 exc: Optional[BaseException],
112 tb: Optional[TracebackType],
116 def __del__(self) -> None:
117 connections = [repr(c)
118 for cs in self._connections.values() for c in cs]
124 "connections": connections,
125 "message": "Unclosed connection pool",
127 self._loop.call_exception_handler(context)
130 class Connection(base):
133 pool: ConnectionPool,
135 reader: asyncio.StreamReader,
136 writer: asyncio.StreamWriter,
140 self._reader = reader
141 self._writer = writer
146 host = f"{self._host[0]}:{self._host[1]}"
147 return f"Connection<{host}>#{id(self)}"
153 def release(self) -> None:
154 logging.debug('releasing %r', self)
156 self._pool._notify_release()
158 async def close(self) -> None:
161 logging.debug('closing %r', self)
164 self._pool._remove(self)
166 await self._writer.wait_closed()
167 except AttributeError: # wait_closed is new in 3.7
async def _hello(self) -> None:
    """Consume HELLO_PREFIX and record the server's protocol version.

    Reads the first line sent by the server, checks that it starts with
    HELLO_PREFIX and stores the remainder of the line in ``self.version``.

    :raises MPDProtocolError: if the greeting does not start with HELLO_PREFIX.
    """
    # Read straight from the underlying StreamReader: connect() awaits
    # _hello() before it sets ``in_use``, and attribute delegation through
    # __getattr__ refuses connections that are not in use.
    data = await self._reader.readuntil(b'\n')
    rcv = data.decode('utf-8')
    if not rcv.startswith(HELLO_PREFIX):
        raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
    log.debug('consumed hello prefix: %r', rcv)
    # readuntil() keeps the trailing separator; drop it before slicing.
    self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
    log.info('protocol version: %s', self.version)
def __getattr__(self, name: str) -> Any:
    """All unknown attributes are delegated to the reader and writer.

    Python only calls this when normal attribute lookup fails.  The reader
    is consulted first; if it lacks *name*, the writer's attribute is
    returned (or its AttributeError propagates).

    :raises ValueError: if the connection is closed or not in use.
    :raises AttributeError: if neither stream has *name*, or if one of the
        connection's own state attributes has not been set.
    """
    # If one of our own state attributes is what is missing (e.g. on an
    # instance built via ``__new__`` without running ``__init__``), the
    # lookups below would re-enter __getattr__ endlessly and end in
    # RecursionError.  Fail like an ordinary missing attribute instead, which
    # also keeps hasattr() working on such instances.
    if name in ('_closed', 'in_use', '_reader', '_writer'):
        raise AttributeError(name)
    if self._closed or not self.in_use:
        raise ValueError("Can't use a closed or unacquired connection")
    if hasattr(self._reader, name):
        return getattr(self._reader, name)
    return getattr(self._writer, name)
189 async def __aenter__(self) -> "Connection":
190 if self._closed or not self.in_use:
191 raise ValueError("Can't use a closed or unacquired connection")
196 exc_type: Optional[Type[BaseException]],
197 exc: Optional[BaseException],
198 tb: Optional[TracebackType],
202 def __del__(self) -> None:
205 context = {"connection": self, "message": "Unclosed connection"}
206 self._pool._loop.call_exception_handler(context)