]> kaliko git repositories - python-musicpdaio.git/blob - mpdaio/connection.py
c47cd7ea3f52705d49ca50f6ba40b0442278f1e9
[python-musicpdaio.git] / mpdaio / connection.py
1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
"""Asyncio connection pool for MPD clients.

Based on https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket
"""
6
7 import asyncio
8 import contextlib
9 import logging
10
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
14
15 from . import HELLO_PREFIX
16 from .exceptions import MPDProtocolError
17
# contextlib.AbstractAsyncContextManager is missing before Python 3.7;
# fall back to a plain ``object`` base class there.
base = getattr(contextlib, 'AbstractAsyncContextManager', object)  # type: ignore

# Aliases describing the address a pooled connection points to.
Server = str  # hostname/IP address, or unix socket path ('/...' or '@...')
Port = int
Host = Tuple[Server, Port]

log = logging.getLogger(__name__)
27
28
class ConnectionPool(base):
    """Pool of reusable asyncio stream connections to MPD servers.

    Connections are grouped per ``(server, port)`` host, and hosts are kept
    in most-recently-used-first order. At most ``max_connections``
    connections can be handed out at the same time; further :meth:`connect`
    calls wait until one is released.

    Use the pool as an async context manager (or call :meth:`close`) so all
    connections get closed; a pool still holding connections reports itself
    to the event loop's exception handler when garbage collected.
    """

    def __init__(
        self,
        max_connections: int = 1000,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        """
        :param max_connections: maximum number of connections in use at the
            same time; also the number of open connections from which an
            idle one is evicted before opening a new one
        :param loop: event loop unclosed resources are reported to, defaults
            to :func:`asyncio.get_event_loop`
        """
        self.max_connections = max_connections
        self._loop = loop or asyncio.get_event_loop()

        # host -> open connections to that host, most-recently-used host first
        self._connections: OrderedDict[Host,
                                       List["Connection"]] = OrderedDict()
        # one permit per connection currently handed out by connect()
        self._semaphore = asyncio.Semaphore(max_connections)

    async def connect(self, server: Server, port: Port, timeout: int) -> "Connection":
        """Return an acquired connection to ``server``/``port``.

        An idle pooled connection to that host is reused when available,
        otherwise a new one is opened (see :meth:`_open`). Waits while
        ``max_connections`` connections are already in use. The caller must
        hand the connection back with ``Connection.release()`` (or by using
        it as an async context manager).

        :param server: hostname/IP, or unix socket path starting with '/' or '@'
        :param port: TCP port (part of the pool key for unix sockets too)
        :param timeout: seconds allowed to open a new socket
        :raises asyncio.TimeoutError: opening the socket exceeded ``timeout``
        :raises MPDProtocolError: the server greeting is not an MPD hello
        """
        host = (server, port)

        # enforce the connection limit, releasing connections notifies
        # the semaphore to release here
        await self._semaphore.acquire()
        try:
            connections = self._connections.setdefault(host, [])
            log.debug('got %s in pool', len(connections))
            # find an un-used connection for this host
            connection = next(
                (conn for conn in connections if not conn.in_use), None)
            if connection is None:
                await self._evict_idle()
                connection = await self._open(host, timeout)
                # look the host list up again: close() may have replaced the
                # mapping while the connection was being opened
                self._connections.setdefault(host, []).append(connection)
        except BaseException:
            # Nothing is handed out, so nobody would ever release this permit:
            # give it back now, otherwise every failed connect (timeout,
            # refused, bad hello, cancellation) permanently shrinks the pool.
            self._semaphore.release()
            raise

        connection.in_use = True
        # move current host to the front as most-recently used
        self._connections.move_to_end(host, False)
        log.debug('connection %s in use', connection)

        return connection

    async def _evict_idle(self) -> None:
        """Close the least-recently-used idle connection if the pool is full."""
        open_count = sum(len(conns) for conns in self._connections.values())
        if open_count < self.max_connections:
            return
        # hosts are ordered most-recently-used first, so walk them backwards;
        # close exactly one idle connection, not one per host
        for conns_per_host in reversed(self._connections.values()):
            idle = next((conn for conn in conns_per_host if not conn.in_use),
                        None)
            if idle is not None:
                await idle.close()
                return

    async def _open(self, host: Host, timeout: int) -> "Connection":
        """Open a new connection to ``host`` and consume the MPD hello line."""
        server, port = host
        if server.startswith(('/', '@')):
            log.debug('about to connect unix socket %s', server)
            opener = asyncio.open_unix_connection(path=server)
        else:
            log.debug('about to connect tcp socket %s:%s', *host)
            opener = asyncio.open_connection(server, port)
        reader, writer = await asyncio.wait_for(opener, timeout)
        connection = Connection(self, host, reader, writer)
        try:
            await connection._hello()
        except BaseException:
            # don't leak the socket when the greeting is invalid
            await connection.close()
            raise
        return connection

    async def close(self):
        """Close all connections"""
        connections = [c for cs in self._connections.values() for c in cs]
        # reset first so Connection.close() -> _remove() finds nothing to prune
        self._connections = OrderedDict()
        log.info('Closing all connections')
        for connection in connections:
            await connection.close()

    def _remove(self, connection):
        """Drop ``connection`` from its host's list (no-op if absent)."""
        conns_for_host = self._connections.get(connection._host)
        if not conns_for_host:
            return
        # in-place slice assignment keeps the list object other code holds
        conns_for_host[:] = [c for c in conns_for_host if c is not connection]

    def _notify_release(self):
        """Return one connection permit to the pool."""
        self._semaphore.release()

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Close every pooled connection on context exit."""
        await self.close()

    def __del__(self) -> None:
        """Report the pool to the event loop if connections are still open."""
        # __init__ may have failed before _connections was assigned
        pooled = getattr(self, '_connections', None)
        if not pooled:
            return
        connections = [repr(c) for cs in pooled.values() for c in cs]
        if not connections:
            return

        context = {
            "pool": self,
            "connections": connections,
            "message": "Unclosed connection pool",
        }
        self._loop.call_exception_handler(context)
128
129
class Connection(base):
    """One pooled stream connection to an MPD server.

    Unknown attribute lookups are delegated to the underlying
    :class:`asyncio.StreamReader` (first) or :class:`asyncio.StreamWriter`,
    but only while the connection is open and acquired.
    """

    def __init__(
        self,
        pool: ConnectionPool,
        host: Host,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        self._host = host
        self._pool = pool
        self._reader = reader
        self._writer = writer
        self._closed = False
        # True while acquired (set by _hello() and the pool), False once released
        self.in_use = False

    def __repr__(self):
        host = f"{self._host[0]}:{self._host[1]}"
        return f"Connection<{host}>#{id(self)}"

    @property
    def closed(self):
        """Whether close() has been called on this connection."""
        return self._closed

    def release(self) -> None:
        """Hand the connection back to the pool for reuse.

        Releasing a connection that is not in use is a no-op, so releasing
        twice cannot inflate the pool's connection limit.
        """
        if not self.in_use:
            return
        # module logger: the root logging.debug() would call basicConfig()
        log.debug('releasing %r', self)
        self.in_use = False
        self._pool._notify_release()

    async def close(self) -> None:
        """Close the stream and remove the connection from its pool (idempotent)."""
        if self._closed:
            return
        log.debug('closing %r', self)
        self._closed = True
        self._writer.close()
        self._pool._remove(self)
        try:
            await self._writer.wait_closed()
        except AttributeError:  # wait_closed is new in 3.7
            pass

    async def _hello(self) -> None:
        """Consume HELLO_PREFIX and store the protocol version in ``version``.

        :raises MPDProtocolError: the first line does not start with HELLO_PREFIX
        """
        # mark in use so the delegated readuntil() below is allowed
        self.in_use = True
        data = await self.readuntil(b'\n')
        # decode leniently: a non-MPD peer sending arbitrary bytes is then
        # reported as a protocol error instead of a UnicodeDecodeError
        rcv = data.decode('utf-8', errors='replace')
        if not rcv.startswith(HELLO_PREFIX):
            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
        log.debug('consumed hello prefix: %r', rcv)
        self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
        log.info('protocol version: %s', self.version)

    def __getattr__(self, name: str) -> Any:
        """All unknown attributes are delegated to the reader and writer

        :raises ValueError: the connection is closed or not acquired
        """
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        if hasattr(self._reader, name):
            return getattr(self._reader, name)
        return getattr(self._writer, name)

    async def __aenter__(self) -> "Connection":
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Release the connection back to the pool on context exit."""
        self.release()

    def __del__(self) -> None:
        """Report a connection garbage collected without being closed."""
        if self._closed:
            return
        context = {"connection": self, "message": "Unclosed connection"}
        self._pool._loop.call_exception_handler(context)
204             return
205         context = {"connection": self, "message": "Unclosed connection"}
206         self._pool._loop.call_exception_handler(context)