kaliko git repositories - python-musicpdaio.git/blob - mpdaio/connection.py
Move protocol logic in the client
[python-musicpdaio.git] / mpdaio / connection.py
1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
4 """https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket#%E2%80%A6
5 """
6
7 import asyncio
8 import contextlib
9 import logging
10
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
14
# Base class for the async context managers below.
# contextlib.AbstractAsyncContextManager was added in Python 3.7; fall back
# to a plain object on older interpreters.
try:
    base = contextlib.AbstractAsyncContextManager
except AttributeError:
    base = object  # type: ignore

#: TCP host name/address, or a unix socket path (starting with "/" or "@")
Server = str
#: TCP port (not used for unix sockets)
Port = int
#: key under which connections are pooled
Host = Tuple[Server, Port]
log = logging.getLogger(__name__)
24
25
class ConnectionPool(base):
    """Pool of reusable asyncio stream connections keyed by ``(server, port)``.

    At most *max_connections* connections are handed out at the same time:
    :meth:`connect` waits on a semaphore whose permits are given back when a
    :class:`Connection` is released. Released connections stay in the pool
    and are reused by later :meth:`connect` calls for the same host.
    """

    def __init__(
        self,
        max_connections: int = 1000,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        """
        :param max_connections: maximum number of connections in use at the
            same time; once the pool holds that many connections, the
            least-recently-used idle one is closed before opening a new one
        :param loop: event loop used to report unclosed pools/connections,
            defaults to :func:`asyncio.get_event_loop`
        """
        self.max_connections = max_connections
        self._loop = loop or asyncio.get_event_loop()

        # Hosts ordered most-recently-used first (see connect), each mapped
        # to all of its connections, in use or idle.
        self._connections: OrderedDict[Host,
                                       List["Connection"]] = OrderedDict()
        # One permit per connection currently handed out.
        self._semaphore = asyncio.Semaphore(max_connections)

    async def connect(self, server: Server, port: Port, timeout: int) -> "Connection":
        """Hand out a connection to *server*, reusing an idle one if possible.

        :param server: host name/address for TCP, or a unix socket path when
            it starts with ``/`` or ``@``
        :param port: TCP port (not used for unix sockets)
        :param timeout: seconds allowed for opening a new connection
        :raises asyncio.TimeoutError: opening a new connection timed out
        :raises OSError: opening a new connection failed
        :returns: a :class:`Connection` flagged as in use; release it (for
            instance with ``async with``) to give its slot back to the pool
        """
        host = (server, port)

        # Enforce the connection limit; releasing a connection notifies the
        # semaphore through _notify_release.
        await self._semaphore.acquire()
        try:
            connection = await self._reuse_idle(host)
            if connection is None:
                connection = await self._open(server, port, timeout)
        except BaseException:
            # No connection is handed out (timeout, refused connection, task
            # cancellation...): give the permit back or the pool runs dry.
            self._semaphore.release()
            raise

        connection.in_use = True
        # Move current host to the front as most-recently used.
        self._connections.move_to_end(host, False)
        log.debug('connection %s in use', connection)
        return connection

    async def _reuse_idle(self, host: Host) -> Optional["Connection"]:
        """Return an idle, still readable connection for *host*, or None."""
        connections = self._connections.setdefault(host, [])
        log.debug('got %s in pool', len(connections))
        # Iterate over a copy: closing a connection removes it from the list.
        for conn in list(connections):
            if conn.in_use:
                continue
            if conn._reader.at_eof():
                # The peer closed the stream while the connection sat idle;
                # it cannot carry another request/response, so drop it.
                log.debug('dropping stale %r', conn)
                with contextlib.suppress(OSError):
                    await conn.close()
                continue
            return conn
        return None

    async def _evict_lru_idle(self) -> None:
        """Close the least-recently-used idle connection if the pool is full."""
        total = sum(len(conns) for conns in self._connections.values())
        if total < self.max_connections:
            return
        # Hosts are ordered most-recently-used first, so walk them backwards
        # and close exactly one idle connection.
        for conns_per_host in reversed(self._connections.values()):
            idle = next((c for c in conns_per_host if not c.in_use), None)
            if idle is not None:
                with contextlib.suppress(OSError):
                    await idle.close()
                return

    async def _open(self, server: Server, port: Port,
                    timeout: int) -> "Connection":
        """Open a new connection to *server* and register it in the pool."""
        await self._evict_lru_idle()
        if server.startswith(('/', '@')):
            log.debug('about to connect unix socket %s', server)
            opening = asyncio.open_unix_connection(path=server)
        else:
            log.debug('about to connect tcp socket %s:%s', server, port)
            opening = asyncio.open_connection(server, port)
        reader, writer = await asyncio.wait_for(opening, timeout)
        host = (server, port)
        connection = Connection(self, host, reader, writer)
        # Look the host list up again: the pool may have been reset while
        # the connection was being opened.
        self._connections.setdefault(host, []).append(connection)
        return connection

    async def close(self):
        """Close all connections, in use or idle, and empty the pool."""
        connections = [c for cs in self._connections.values() for c in cs]
        self._connections = OrderedDict()
        log.info('Closing all connections')
        for connection in connections:
            # Best effort: one failing transport must not keep the others
            # open.
            try:
                await connection.close()
            except OSError as err:
                log.debug('error closing %r: %s', connection, err)

    def _remove(self, connection):
        """Forget *connection* (called by Connection.close)."""
        conns_for_host = self._connections.get(connection._host)
        if not conns_for_host:
            return
        # Slice assignment keeps the list object shared with any caller
        # currently holding a reference to it.
        conns_for_host[:] = [c for c in conns_for_host if c != connection]

    def _notify_release(self):
        """Give one connection slot back (called by Connection.release)."""
        self._semaphore.release()

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    def __del__(self) -> None:
        """Report a pool garbage-collected with connections still pooled."""
        # __init__ may have failed before _connections was assigned.
        pooled = getattr(self, '_connections', None)
        if not pooled:
            return
        connections = [repr(c) for cs in pooled.values() for c in cs]
        if not connections:
            return

        context = {
            "pool": self,
            "connections": connections,
            "message": "Unclosed connection pool",
        }
        self._loop.call_exception_handler(context)
124
125
class Connection(base):
    """A pooled stream connection to one ``(server, port)`` host.

    Unknown attributes are delegated to the underlying
    :class:`asyncio.StreamReader`, then :class:`asyncio.StreamWriter`, but
    only while the connection is open and handed out by the pool.
    """

    def __init__(
        self,
        pool: ConnectionPool,
        host: Host,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        self._host = host
        self._pool = pool
        self._reader = reader
        self._writer = writer
        self._closed = False
        #: password command with the secret was sent
        self.auth: bool = False
        #: True while the pool has handed this connection out
        self.in_use = False
        self.version: str | None = None

    def __repr__(self):
        host = f"{self._host[0]}:{self._host[1]}"
        return f"Connection<{host}>"

    @property
    def closed(self):
        """True once :meth:`close` has been called."""
        return self._closed

    def release(self) -> None:
        """Give the connection back to the pool so it can be reused.

        Releasing a connection that is not in use is a no-op, so the pool's
        semaphore is released at most once per acquisition.
        """
        if not self.in_use:
            return
        # Module logger: logging.debug() would configure the root logger.
        log.debug('releasing %r', self)
        self.in_use = False
        self._pool._notify_release()

    async def close(self) -> None:
        """Close the underlying stream and remove it from the pool.

        Calling it again is a no-op. An OSError raised while waiting for the
        transport to finish closing is logged and ignored.
        """
        if self._closed:
            return
        log.debug('closing %r', self)
        self._closed = True
        self._writer.close()
        self._pool._remove(self)
        wait_closed = getattr(self._writer, 'wait_closed', None)
        if wait_closed is None:  # wait_closed is new in 3.7
            return
        try:
            await wait_closed()
        except OSError as err:
            log.debug('%r closed with error: %s', self, err)

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attributes to the reader, then the writer.

        :raises ValueError: if the connection is closed or not acquired
        """
        if name in ('_closed', 'in_use', '_reader', '_writer'):
            # Only reached on an instance whose __init__ did not run (e.g.
            # copy/pickle); looking them up below would recurse forever.
            raise AttributeError(name)
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        if hasattr(self._reader, name):
            return getattr(self._reader, name)
        return getattr(self._writer, name)

    async def __aenter__(self) -> "Connection":
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        self.release()

    def __del__(self) -> None:
        """Report a connection garbage-collected without being closed."""
        # Default True: a half-initialised instance has nothing to report.
        if getattr(self, '_closed', True):
            return
        context = {"connection": self, "message": "Unclosed connection"}
        self._pool._loop.call_exception_handler(context)