kaliko git repositories - python-musicpdaio.git/blob - mpdaio/connection.py
Some improvement on ConnectionPool
[python-musicpdaio.git] / mpdaio / connection.py
1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
"""Pool of reusable asyncio stream connections to MPD servers.

Reference: https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket
"""
6
7 import asyncio
8 import contextlib
9 import logging
10
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
14
15 from . import HELLO_PREFIX
16 from .exceptions import MPDProtocolError
17
try:  # Python 3.7
    base = contextlib.AbstractAsyncContextManager
except AttributeError:
    # Python < 3.7 has no AbstractAsyncContextManager; plain object is enough
    # since the classes below define __aenter__/__aexit__ themselves.
    base = object  # type: ignore

# Type aliases used in the signatures below.
Server = str  # host name or address handed to asyncio.open_connection
Port = int
Host = Tuple[Server, Port]  # key used to group connections in the pool
log = logging.getLogger(__name__)
27
28
class ConnectionPool(base):
    """Pool of reusable stream connections, grouped by ``(server, port)``.

    At most ``max_connections`` connections are handed out at the same time
    (a semaphore, released through :meth:`Connection.release`).  Idle
    connections to the requested host are reused; when a new connection is
    needed and the pool already holds ``max_connections`` connections, the
    idle connection of the least-recently used host is closed to make room.
    """

    def __init__(
        self,
        max_connections: int = 1000,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self.max_connections = max_connections
        self._loop = loop or asyncio.get_event_loop()

        # host -> connections to that host.  The most-recently used host is
        # kept first, so reversed() iteration starts at the LRU host.
        self._connections: OrderedDict[Host,
                                       List["Connection"]] = OrderedDict()
        # one slot per connection currently handed out to a caller
        self._semaphore = asyncio.Semaphore(max_connections)

    async def connect(self, server: Server, port: Port,
                      timeout: Optional[float]) -> "Connection":
        """Check out a connection to ``server:port``.

        Waits while ``max_connections`` connections are in use.  An idle
        pooled connection to the same host is reused when there is one,
        otherwise a new connection is opened and greeted.

        :param server: server host name or address
        :param port: server port
        :param timeout: seconds to wait for the TCP connection to open
        :returns: the connection, marked in use; hand it back with
                  :meth:`Connection.release` (or ``async with``)
        :raises MPDProtocolError: the server greeting is not an MPD hello
        :raises asyncio.TimeoutError: opening took longer than *timeout*
        """
        host = (server, port)

        # enforce the connection limit, releasing connections notifies
        # the semaphore to release here
        await self._semaphore.acquire()
        try:
            connections = self._connections.setdefault(host, [])
            log.debug('got %s in pool', len(connections))
            # find an un-used connection for this host
            connection = next(
                (conn for conn in connections if not conn.in_use), None)
            if connection is None:
                connection = await self._open(host, timeout)
        except BaseException:
            # Nothing was handed out: give the slot back, otherwise every
            # failed attempt (timeout, refused, bad hello, cancellation)
            # would permanently shrink the pool.
            self._semaphore.release()
            raise

        # no await since the lookup/open above, so nobody else can grab it
        connection.in_use = True
        # move current host to the front as most-recently used
        self._connections.move_to_end(host, False)
        log.debug('connection %s in use', connection)
        return connection

    async def _open(self, host: Host,
                    timeout: Optional[float]) -> "Connection":
        """Open, greet and register a new connection to *host*."""
        server, port = host
        if self._count_open() >= self.max_connections:
            # Make space for the new connection.  We hold a semaphore slot,
            # so at most max_connections - 1 pooled connections are in use
            # and there should be an idle one to drop.
            idle = self._least_recently_used_idle()
            if idle is not None:
                await idle.close()

        log.debug('about to connect %s', host)
        reader, writer = await asyncio.wait_for(
            asyncio.open_connection(server, port),
            timeout
        )
        log.info('Connected to %s:%s', server, port)
        connection = Connection(self, host, reader, writer)
        try:
            await connection._hello()
        except BaseException:
            # Do not leak the socket on a failed/invalid greeting.
            await connection.close()
            raise
        # Look the list up again: the dict may have been replaced (close())
        # while we were awaiting above.
        self._connections.setdefault(host, []).append(connection)
        return connection

    def _count_open(self) -> int:
        """Return the number of connections currently registered."""
        return sum(len(conns) for conns in self._connections.values())

    def _least_recently_used_idle(self) -> Optional["Connection"]:
        """Return an idle connection, least-recently used host first.

        Runs without awaiting so the dict cannot be mutated by another task
        while it is being iterated.
        """
        for conns_per_host in reversed(self._connections.values()):
            for conn in conns_per_host:
                if not conn.in_use:
                    return conn
        return None

    async def close(self):
        """Close all connections, including those currently in use."""
        connections = [c for cs in self._connections.values() for c in cs]
        self._connections = OrderedDict()
        for connection in connections:
            await connection.close()

    def _remove(self, connection: "Connection") -> None:
        """Forget *connection*; called by :meth:`Connection.close`."""
        conns_for_host = self._connections.get(connection._host)
        if not conns_for_host:
            return
        # in place, so list references held elsewhere stay in sync
        conns_for_host[:] = [c for c in conns_for_host if c is not connection]

    def _notify_release(self):
        """Give one slot back; called by :meth:`Connection.release`."""
        self._semaphore.release()

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    def __del__(self) -> None:
        """Report connections left open when the pool is garbage collected."""
        # __init__ may not have completed; nothing to report then.
        pool = getattr(self, '_connections', None)
        if not pool:
            return
        connections = [repr(c) for cs in pool.values() for c in cs]
        if not connections:
            return

        context = {
            "pool": self,
            "connections": connections,
            "message": "Unclosed connection pool",
        }
        self._loop.call_exception_handler(context)
119
120
class Connection(base):
    """One pooled stream connection, handed out by :class:`ConnectionPool`.

    Unknown attributes (e.g. ``readuntil``) are looked up on the
    StreamReader first, then on the StreamWriter; this only works while the
    connection is open and checked out (``in_use``).
    """

    def __init__(
        self,
        pool: ConnectionPool,
        host: Host,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        self._host = host
        self._pool = pool
        self._reader = reader
        self._writer = writer
        self._closed = False
        self.in_use = False

    def __repr__(self):
        host = f"{self._host[0]}:{self._host[1]}"
        return f"Connection<{host}>#{id(self)}"

    @property
    def closed(self):
        """True once :meth:`close` has been called."""
        return self._closed

    def release(self) -> None:
        """Hand the connection back to the pool; safe to call twice."""
        if not self.in_use:
            # Already released (or never checked out): releasing the pool
            # semaphore again would push its limit above max_connections.
            return
        log.debug('releasing %r', self)
        self.in_use = False
        self._pool._notify_release()

    async def close(self) -> None:
        """Close the underlying stream and drop it from the pool."""
        if self._closed:
            return
        log.debug('closing %r', self)
        self._closed = True
        self._writer.close()
        self._pool._remove(self)
        try:
            await self._writer.wait_closed()
        except AttributeError:  # wait_closed is new in 3.7
            pass

    async def _hello(self) -> None:
        """Consume HELLO_PREFIX and store the announced protocol version.

        :raises MPDProtocolError: the first line is not an MPD greeting
        """
        # Must be in use, otherwise __getattr__ refuses to delegate readuntil.
        self.in_use = True
        data = await self.readuntil(b'\n')
        rcv = data.decode('utf-8')
        if not rcv.startswith(HELLO_PREFIX):
            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
        log.debug('consumed hello prefix: %r', rcv)
        self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
        log.info('protocol version: %s', self.version)

    def __getattr__(self, name: str) -> Any:
        """All unknown attributes are delegated to the reader and writer"""
        if name in ('_host', '_pool', '_reader', '_writer', '_closed',
                    'in_use'):
            # Normally set by __init__ and never routed here.  Missing only
            # on a half-initialised instance (e.g. copy.copy probing
            # __setstate__); looking them up below would recurse forever.
            raise AttributeError(name)
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        if hasattr(self._reader, name):
            return getattr(self._reader, name)
        return getattr(self._writer, name)

    async def __aenter__(self) -> "Connection":
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        self.release()

    def __del__(self) -> None:
        """Report a connection garbage collected without being closed."""
        # getattr: a half-initialised instance has nothing to report
        if getattr(self, '_closed', True):
            return
        context = {"connection": self, "message": "Unclosed connection"}
        self._pool._loop.call_exception_handler(context)