kaliko git repositories - python-musicpdaio.git/blob - mpdaio/connection.py
Add per connection command wrapper
[python-musicpdaio.git] / mpdaio / connection.py
1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
"""Asyncio connection pool and pooled connections for MPD clients.

Adapted from https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket
"""
6
7 import asyncio
8 import contextlib
9 import logging
10
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
14
15 from . import HELLO_PREFIX
16 from .exceptions import MPDProtocolError
17
# contextlib.AbstractAsyncContextManager is only available from Python 3.7
# on; fall back to a plain ``object`` base class on older interpreters.
try:  # Python 3.7
    base = contextlib.AbstractAsyncContextManager
except AttributeError:
    base = object  # type: ignore

Server = str  # server name or address, as passed to asyncio.open_connection
Port = int  # TCP port of the server
Host = Tuple[Server, Port]  # key identifying a server in the pool
log = logging.getLogger(__name__)
27
28
class ConnectionPool(base):
    """Pool of reusable connections keyed by ``(server, port)``.

    At most ``max_connections`` connections are handed out at a time:
    :meth:`connect` waits on a semaphore until a connection is released.
    Released connections stay in the pool for reuse. When the pool is full
    and a new connection must be opened, the least-recently-used idle
    connection is closed to make room.
    """

    def __init__(
        self,
        max_connections: int = 1000,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self.max_connections = max_connections
        self._loop = loop or asyncio.get_event_loop()

        # hosts ordered most-recently-used first (see connect())
        self._connections: OrderedDict[Host,
                                       List["Connection"]] = OrderedDict()
        # counts connections currently handed out to callers
        self._semaphore = asyncio.Semaphore(max_connections)

    async def connect(self, server: Server, port: Port, timeout: int) -> "Connection":
        """Return an acquired connection to ``server:port``.

        An idle pooled connection for this host is reused when available.
        Otherwise a new one is opened, waiting at most ``timeout`` seconds,
        and its hello line is consumed.

        The returned connection is marked in use. Hand it back with
        :meth:`Connection.release` (or by leaving ``async with connection:``).

        :raises asyncio.TimeoutError: opening the connection timed out.
        :raises OSError: the connection could not be established.
        :raises MPDProtocolError: the server greeting is not a valid hello.
        """
        host = (server, port)

        # enforce the connection limit, releasing connections notifies
        # the semaphore to release here
        await self._semaphore.acquire()
        try:
            connection = await self._get_or_open(host, timeout)
        except BaseException:
            # The caller receives no connection and so will never call
            # release(): give the slot back now, otherwise repeated failures
            # (timeouts, refused connections, cancellation) starve the pool.
            self._semaphore.release()
            raise

        connection.in_use = True
        # move current host to the front as most-recently used
        self._connections.move_to_end(host, False)
        log.debug('connection %s in use', connection)

        return connection

    async def _get_or_open(self, host: Host, timeout: int) -> "Connection":
        """Return an idle pooled connection for ``host`` or open a new one."""
        connections = self._connections.setdefault(host, [])
        log.debug('got %s in pool', len(connections))
        # find an un-used connection for this host
        connection = next(
            (conn for conn in connections if not conn.in_use), None)
        if connection is not None:
            return connection

        # only evict when opening one more connection would exceed the limit
        if self._size() >= self.max_connections:
            await self._evict_idle()

        log.debug('about to connect %s', host)
        reader, writer = await asyncio.wait_for(
            asyncio.open_connection(*host),
            timeout
        )
        connection = Connection(self, host, reader, writer)
        try:
            await connection._hello()
        except BaseException:
            # do not leak the socket of a failed handshake
            await connection.close()
            raise
        connections.append(connection)
        return connection

    def _size(self) -> int:
        """Number of pooled connections, in use or idle, across all hosts."""
        return sum(len(conns) for conns in self._connections.values())

    async def _evict_idle(self) -> None:
        """Close the single least-recently-used idle connection, if any."""
        # hosts are ordered most-recently-used first, so walk them backwards;
        # return right after the first close (which also mutates the list)
        for conns_per_host in reversed(self._connections.values()):
            for conn in conns_per_host:
                if not conn.in_use:
                    await conn.close()
                    return

    async def close(self):
        """Close all connections"""
        connections = [c for cs in self._connections.values() for c in cs]
        self._connections = OrderedDict()
        log.info('Closing all connections')
        for connection in connections:
            await connection.close()

    def _remove(self, connection):
        """Drop ``connection`` from its host's list, if still pooled."""
        conns_for_host = self._connections.get(connection._host)
        if not conns_for_host:
            return
        # slice assignment keeps the list object shared with any caller
        conns_for_host[:] = [c for c in conns_for_host if c != connection]

    def _notify_release(self):
        """Free one connection slot; called by Connection.release()."""
        self._semaphore.release()

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    def __del__(self) -> None:
        # getattr: __init__ may have failed before _connections was assigned
        connections = [repr(c)
                       for cs in getattr(self, '_connections', {}).values()
                       for c in cs]
        if not connections:
            return

        context = {
            "pool": self,
            "connections": connections,
            "message": "Unclosed connection pool",
        }
        self._loop.call_exception_handler(context)
122
123
class Connection(base):
    """A pooled connection wrapping an asyncio reader/writer pair.

    While the connection is open and acquired (``in_use``), unknown
    attributes are delegated to the reader first, then to the writer. This
    lets callers use e.g. ``readuntil`` directly on the connection.
    """

    def __init__(
        self,
        pool: ConnectionPool,
        host: Host,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        self._host = host
        self._pool = pool
        self._reader = reader
        self._writer = writer
        self._closed = False
        # True while a caller holds the connection (set by the pool)
        self.in_use = False

    def __repr__(self):
        host = f"{self._host[0]}:{self._host[1]}"
        return f"Connection<{host}>#{id(self)}"

    @property
    def closed(self):
        """Whether close() has been called on this connection."""
        return self._closed

    def release(self) -> None:
        """Hand the connection back to its pool for reuse.

        Releasing a connection that is not in use is a no-op. Notifying the
        pool twice would release its semaphore twice and silently allow more
        than ``max_connections`` connections to be in use.
        """
        log.debug('releasing %r', self)
        if not self.in_use:
            return
        self.in_use = False
        self._pool._notify_release()

    async def close(self) -> None:
        """Close the stream and remove the connection from its pool.

        Calling it again on a closed connection does nothing.
        """
        if self._closed:
            return
        log.debug('closing %r', self)
        self._closed = True
        self._writer.close()
        self._pool._remove(self)
        try:
            await self._writer.wait_closed()
        except AttributeError:  # wait_closed is new in 3.7
            pass

    async def _hello(self) -> None:
        """Consume HELLO_PREFIX and record the protocol version.

        :raises MPDProtocolError: the first line does not start with
            HELLO_PREFIX.
        """
        # mark in use so __getattr__ lets readuntil through
        self.in_use = True
        data = await self.readuntil(b'\n')
        rcv = data.decode('utf-8')
        if not rcv.startswith(HELLO_PREFIX):
            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
        log.debug('consumed hello prefix: %r', rcv)
        self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
        log.info('protocol version: %s', self.version)

    def __getattr__(self, name: str) -> Any:
        """All unknown attributes are delegated to the reader and writer

        :raises ValueError: the connection is closed or not acquired.
        """
        state = self.__dict__
        if not {'_closed', 'in_use', '_reader', '_writer'} <= state.keys():
            # Partially built instance (e.g. during copy/unpickle): reading
            # self._closed below would re-enter __getattr__ endlessly.
            raise AttributeError(name)
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        if hasattr(self._reader, name):
            return getattr(self._reader, name)
        return getattr(self._writer, name)

    async def __aenter__(self) -> "Connection":
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        self.release()

    def __del__(self) -> None:
        # read through __dict__ to avoid __getattr__ on partial instances
        if self.__dict__.get('_closed', True):
            return
        context = {"connection": self, "message": "Unclosed connection"}
        self._pool._loop.call_exception_handler(context)