]> kaliko git repositories - python-musicpdaio.git/blob - mpdaio/connection.py
Partially implement protocol (AsyncIO POC)
[python-musicpdaio.git] / mpdaio / connection.py
1 # -*- coding: utf-8 -*-
2 # SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
3 # SPDX-License-Identifier: LGPL-3.0-or-later
"""Pool of reusable asyncio stream connections keyed by (server, port).

Reference: https://stackoverflow.com/questions/55879847/asyncio-how-to-reuse-a-socket
"""
6
7 import asyncio
8 import contextlib
9 import logging
10
11 from collections import OrderedDict
12 from types import TracebackType
13 from typing import Any, List, Optional, Tuple, Type
14
15
try:
    # contextlib.AbstractAsyncContextManager only exists on Python >= 3.7;
    # older interpreters fall back to a plain base class.
    base = contextlib.AbstractAsyncContextManager
except AttributeError:
    base = object  # type: ignore

# Type aliases for the key identifying a pooled host.
Server = str
Port = int
Host = Tuple[Server, Port]

log = logging.getLogger(__name__)
25
26
class ConnectionPool(base):
    """Pool of reusable asyncio stream connections keyed by ``(server, port)``.

    At most ``max_connections`` connections are handed out at the same time
    (enforced with a semaphore). Released connections stay open and are
    reused for the same host. When the pool already holds
    ``max_connections`` connections and a new one has to be opened, the
    least-recently-used idle connection is closed first.
    """

    def __init__(
        self,
        max_connections: int = 1000,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self.max_connections = max_connections
        # used by the __del__ hooks to report unclosed resources
        self._loop = loop or asyncio.get_event_loop()

        # host -> its connections; most-recently-used host first
        self._connections: OrderedDict[Host,
                                       List["Connection"]] = OrderedDict()
        self._semaphore = asyncio.Semaphore(max_connections)

    async def connect(self, server: Server, port: Port) -> "Connection":
        """Acquire a connection to ``server:port``.

        Waits until a pool slot is free, then reuses an idle connection to
        the same host or opens a new one. The returned connection is marked
        in use; call its ``release()`` (or use it with ``async with``) to
        hand it back.

        Raises:
            Any error from ``asyncio.open_connection`` (e.g. ``OSError``).
            The pool slot taken for the failed attempt is released again.
        """
        host = (server, port)

        # enforce the connection limit, releasing connections notifies
        # the semaphore to release here
        await self._semaphore.acquire()
        try:
            connection = self._idle_connection(host)
            if connection is None:
                await self._make_room()
                reader, writer = await asyncio.open_connection(server, port)
                connection = Connection(self, host, reader, writer)
                # Look the host list up again: the pool may have been reset
                # or the entry pruned while we were awaiting.
                self._connections.setdefault(host, []).append(connection)
        except BaseException:
            # A failed or cancelled attempt must not consume a slot forever.
            self._semaphore.release()
            raise

        connection.in_use = True
        # move current host to the front as most-recently used
        self._connections.move_to_end(host, last=False)

        return connection

    def _idle_connection(self, host: Host) -> Optional["Connection"]:
        """Return an un-used pooled connection to ``host``, or None."""
        connections = self._connections.get(host, [])
        log.debug('got %s in pool', len(connections))
        return next((conn for conn in connections if not conn.in_use), None)

    async def _make_room(self) -> None:
        """Close the least-recently-used idle connection if the pool is full."""
        total = sum(len(conns) for conns in self._connections.values())
        if total < self.max_connections:
            return
        # Iterate over a snapshot (least-recently-used host first): closing
        # awaits, and other tasks may modify the pool meanwhile.
        for conns_per_host in reversed(list(self._connections.values())):
            victim = next((c for c in conns_per_host if not c.in_use), None)
            if victim is None:
                continue
            try:
                await victim.close()
            except OSError as err:
                # The victim is already removed from the pool; a peer reset
                # must not make the caller's connect() fail.
                log.debug('error closing evicted %r: %s', victim, err)
            return

    async def close(self) -> None:
        """Close all connections (in use or not) and empty the pool."""
        connections = [c for cs in self._connections.values() for c in cs]
        self._connections = OrderedDict()
        for connection in connections:
            try:
                await connection.close()
            except OSError as err:
                # keep going: one broken transport must not leave the
                # remaining connections open
                log.debug('error closing %r: %s', connection, err)

    def _remove(self, connection: "Connection") -> None:
        """Forget ``connection``; drop its host entry once it is empty."""
        host = connection._host
        conns_for_host = self._connections.get(host)
        if not conns_for_host:
            return
        conns_for_host[:] = [c for c in conns_for_host if c is not connection]
        if not conns_for_host:
            del self._connections[host]

    def _notify_release(self) -> None:
        """Free one pool slot (called by ``Connection.release``)."""
        self._semaphore.release()

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    def __del__(self) -> None:
        """Report connections still held by a pool that was never closed."""
        # __init__ may have failed before _connections was set.
        pooled = getattr(self, '_connections', None) or {}
        connections = [repr(c) for cs in pooled.values() for c in cs]
        if not connections:
            return

        context = {
            "pool": self,
            "connections": connections,
            "message": "Unclosed connection pool",
        }
        self._loop.call_exception_handler(context)
110
111
class Connection(base):
    """A pooled stream connection handed out by ``ConnectionPool.connect``.

    Wraps an ``asyncio.StreamReader``/``asyncio.StreamWriter`` pair. While the
    connection is open and acquired, unknown attributes are delegated to the
    reader first, then to the writer. Leaving an ``async with`` block releases
    the connection back to its pool; it does not close it.
    """

    def __init__(
        self,
        pool: ConnectionPool,
        host: Host,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        self._host = host
        self._pool = pool
        self._reader = reader
        self._writer = writer
        self._closed = False
        # set by ConnectionPool.connect(), cleared again by release()
        self.in_use = False

    def __repr__(self):
        host = f"{self._host[0]}:{self._host[1]}"
        return f"Connection<{host}>"

    @property
    def closed(self):
        """True once close() has been called."""
        return self._closed

    def release(self) -> None:
        """Hand the connection back to the pool without closing it.

        Releasing a connection that is not in use is a no-op: notifying the
        pool twice would free a semaphore slot that was never taken and let
        more than ``max_connections`` callers hold connections.
        """
        if not self.in_use:
            return
        log.debug('releasing %r', self)
        self.in_use = False
        self._pool._notify_release()

    async def close(self) -> None:
        """Close the underlying stream and remove it from the pool.

        Calling it again is a no-op. It does not free the pool slot of an
        acquired connection; ``release()`` still has to be called for that.
        """
        if self._closed:
            return
        log.debug('closing %r', self)
        self._closed = True
        self._writer.close()
        self._pool._remove(self)
        try:
            await self._writer.wait_closed()
        except AttributeError:  # wait_closed is new in 3.7
            pass

    def __getattr__(self, name: str) -> Any:
        """All unknown attributes are delegated to the reader and writer.

        Raises:
            AttributeError: the instance is not fully initialised, or neither
                stream has ``name``.
            ValueError: the connection is closed or not acquired.
        """
        # __getattr__ only runs for *missing* attributes. Reading
        # self._closed on a half-built instance (failed __init__, copy.copy)
        # would re-enter __getattr__ forever, so use the instance dict.
        state = vars(self)
        if '_reader' not in state or '_writer' not in state:
            raise AttributeError(name)
        if state.get('_closed', True) or not state.get('in_use', False):
            raise ValueError("Can't use a closed or unacquired connection")
        reader = state['_reader']
        if hasattr(reader, name):
            return getattr(reader, name)
        return getattr(state['_writer'], name)

    async def __aenter__(self) -> "Connection":
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        self.release()

    def __del__(self) -> None:
        """Report a connection that was garbage-collected while still open."""
        # Instance dict again: if __init__ did not complete, attribute access
        # would fall through to __getattr__.
        if vars(self).get('_closed', True):
            return
        context = {"connection": self, "message": "Unclosed connection"}
        self._pool._loop.call_exception_handler(context)
175             return
176         context = {"connection": self, "message": "Unclosed connection"}
177         self._pool._loop.call_exception_handler(context)