]> kaliko git repositories - python-musicpdaio.git/blob - mpdaio/connection.py
d00f85af9755f81b84299dfe999ac624bb8f0e4c
[python-musicpdaio.git] / mpdaio / connection.py
1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
"""Asyncio pool of reusable MPD connections.

Based on https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket
"""
6
7 import asyncio
8 import contextlib
9 import logging
10
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
14
15 from .const import HELLO_PREFIX
16 from .exceptions import MPDProtocolError
17
# contextlib.AbstractAsyncContextManager only exists from Python 3.7 on;
# older interpreters fall back to a plain ``object`` base class.
base = getattr(contextlib, 'AbstractAsyncContextManager', object)

# Aliases describing an MPD endpoint.
Server = str  # host name / address, or unix socket path ('/' or '@' prefix)
Port = int
Host = Tuple[Server, Port]

log = logging.getLogger(__name__)
27
28
class ConnectionPool(base):
    """Pool of reusable :class:`Connection` objects keyed by ``(server, port)``.

    At most *max_connections* connections are handed out (in use) at the same
    time: :meth:`connect` waits on a semaphore that :meth:`Connection.release`
    signals. Idle connections are kept for reuse. Once the pool holds
    *max_connections* connections, the least-recently-used idle one is closed
    before a new connection is opened.
    """

    def __init__(
        self,
        max_connections: int = 1000,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        """
        :param max_connections: maximum number of connections in use at once,
            and pool size from which idle connections get evicted.
        :param loop: event loop used to report unclosed pools/connections;
            defaults to ``asyncio.get_event_loop()``.
        """
        self.max_connections = max_connections
        self._loop = loop or asyncio.get_event_loop()

        # Hosts ordered from most- to least-recently used (see connect()).
        self._connections: OrderedDict[Host,
                                       List["Connection"]] = OrderedDict()
        self._semaphore = asyncio.Semaphore(max_connections)

    async def connect(self, server: Server, port: Port,
                      timeout: Optional[float]) -> "Connection":
        """Return an acquired connection to *server*:*port*.

        An idle pooled connection for that host is reused when available,
        otherwise a new one is opened and its MPD hello consumed. The returned
        connection is marked in use; give it back with
        :meth:`Connection.release` (or by using it as an async context
        manager).

        :param timeout: seconds allowed to open the socket, passed to
            :func:`asyncio.wait_for` (``None`` waits forever).
        :raises asyncio.TimeoutError: the socket could not be opened in time.
        :raises MPDProtocolError: the server did not send a valid MPD hello.
        """
        host = (server, port)

        # Enforce the in-use limit; releasing a connection notifies the
        # semaphore (via _notify_release) to wake up a waiter here.
        await self._semaphore.acquire()
        try:
            connections = self._connections.setdefault(host, [])
            log.debug('got %s in pool', len(connections))
            # No await between lookup and marking in use below, so two tasks
            # cannot grab the same idle connection.
            connection = next(
                (conn for conn in connections if not conn.in_use), None)
            if connection is None:
                await self._evict_idle()
                connection = await self._open(host, timeout)
                # Re-fetch the list: close() may have reset the pool while
                # we were awaiting.
                self._connections.setdefault(host, []).append(connection)
        except BaseException:
            # Nothing was handed out: return the slot, otherwise every failed
            # attempt would permanently shrink the pool.
            self._semaphore.release()
            raise

        connection.in_use = True
        # move current host to the front as most-recently used
        self._connections.move_to_end(host, False)
        log.debug('connection %s in use', connection)

        return connection

    async def _evict_idle(self) -> None:
        """Close the least-recently-used idle connection if the pool is full."""
        hosts = list(self._connections.values())
        if sum(len(conns) for conns in hosts) < self.max_connections:
            return
        # At most max_connections - 1 connections are in use here (this task
        # holds a semaphore slot), so a full pool has at least one idle one.
        for conns_per_host in reversed(hosts):
            for conn in conns_per_host:
                if not conn.in_use:
                    await conn.close()
                    return  # evict exactly one connection

    async def _open(self, host: Host,
                    timeout: Optional[float]) -> "Connection":
        """Open a new connection to *host* and consume the MPD hello line."""
        server, port = host
        if server.startswith(('/', '@')):
            # NOTE(review): an '@' prefix is passed to open_unix_connection
            # unchanged; Linux abstract sockets normally need a leading NUL
            # byte — confirm this is intended.
            log.debug('about to connect unix socket %s', server)
            opening = asyncio.open_unix_connection(path=server)
        else:
            log.debug('about to connect tcp socket %s:%s', *host)
            opening = asyncio.open_connection(server, port)
        reader, writer = await asyncio.wait_for(opening, timeout)
        connection = Connection(self, host, reader, writer)
        try:
            await connection._hello()
        except BaseException:
            # Don't leak the socket (nor trigger "Unclosed connection") when
            # the peer is not a valid MPD server; keep the original error.
            with contextlib.suppress(Exception):
                await connection.close()
            raise
        return connection

    async def close(self):
        """Close all pooled connections (in use or idle) and empty the pool."""
        connections = [c for cs in self._connections.values() for c in cs]
        self._connections = OrderedDict()
        log.info('Closing all connections')
        for connection in connections:
            await connection.close()

    def _remove(self, connection):
        """Drop *connection* from its host's list (called by Connection.close)."""
        conns_for_host = self._connections.get(connection._host)
        if not conns_for_host:
            return
        # Slice assignment mutates the list in place.
        conns_for_host[:] = [c for c in conns_for_host if c != connection]

    def _notify_release(self):
        """Free one in-use slot (called by Connection.release)."""
        self._semaphore.release()

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    def __del__(self) -> None:
        """Report still-open connections through the loop's exception handler."""
        # getattr default: __init__ may have failed before setting the dict.
        connections = [repr(c)
                       for cs in getattr(self, '_connections', {}).values()
                       for c in cs]
        if not connections:
            return

        context = {
            "pool": self,
            "connections": connections,
            "message": "Unclosed connection pool",
        }
        self._loop.call_exception_handler(context)
128
129
class Connection(base):
    """A pooled MPD connection wrapping an asyncio reader/writer pair.

    Unknown attributes are delegated to the :class:`asyncio.StreamReader`
    first, then to the :class:`asyncio.StreamWriter`, but only while the
    connection is open and acquired (``in_use``).
    """

    def __init__(
        self,
        pool: ConnectionPool,
        host: Host,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        self._host = host
        self._pool = pool
        self._reader = reader
        self._writer = writer
        self._closed = False
        # Not read in this module. NOTE(review): presumably maintained by
        # callers once authenticated — confirm.
        self.auth = False
        self.in_use = False
        self.version: str | None = None  # set from the hello line by _hello()

    def __repr__(self):
        host = f"{self._host[0]}:{self._host[1]}"
        return f"Connection<{host}>"

    @property
    def closed(self):
        """True once :meth:`close` has been called."""
        return self._closed

    def release(self) -> None:
        """Give the connection back to its pool for reuse.

        Releasing a connection that is not in use is a no-op. A manual
        release followed by leaving an ``async with`` block therefore does
        not free the pool slot twice.
        """
        if not self.in_use:
            log.debug('%r already released', self)
            return
        log.debug('releasing %r', self)
        self.in_use = False
        self._pool._notify_release()

    async def close(self) -> None:
        """Close the stream and remove the connection from its pool.

        Calling it again once closed does nothing.
        """
        if self._closed:
            return
        log.debug('closing %r', self)
        self._closed = True
        self._writer.close()
        self._pool._remove(self)
        # StreamWriter.wait_closed() is new in Python 3.7; look it up instead
        # of catching AttributeError around the await, which would also hide
        # errors raised inside wait_closed().
        wait_closed = getattr(self._writer, 'wait_closed', None)
        if wait_closed is not None:
            await wait_closed()

    async def _hello(self) -> None:
        """Consume the MPD greeting line and store the server version.

        :raises MPDProtocolError: the first line does not start with
            HELLO_PREFIX (including non UTF-8 garbage).
        """
        # Mark in use so the delegated readuntil() below is allowed.
        self.in_use = True
        data = await self.readuntil(b'\n')
        # errors='replace': a non-MPD peer sending binary data is reported
        # as MPDProtocolError below, not as a UnicodeDecodeError.
        rcv = data.decode('utf-8', errors='replace')
        if not rcv.startswith(HELLO_PREFIX):
            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
        log.debug('consumed hello prefix: %r', rcv)
        self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]

    def __getattr__(self, name: str) -> Any:
        """All unknown attributes are delegated to the reader and writer.

        :raises AttributeError: the instance is not initialised (e.g. while
            being copied or unpickled). Reading ``self._closed`` there would
            re-enter __getattr__ forever.
        :raises ValueError: the connection is closed or not acquired.
        """
        state = self.__dict__
        try:
            closed, in_use = state['_closed'], state['in_use']
            reader, writer = state['_reader'], state['_writer']
        except KeyError:
            raise AttributeError(name) from None
        if closed or not in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        if hasattr(reader, name):
            return getattr(reader, name)
        return getattr(writer, name)

    async def __aenter__(self) -> "Connection":
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        self.release()

    def __del__(self) -> None:
        """Report an unclosed connection through the loop's exception handler."""
        # getattr default: __init__ may not have run.
        if getattr(self, '_closed', True):
            return
        context = {"connection": self, "message": "Unclosed connection"}
        self._pool._loop.call_exception_handler(context)