]> kaliko git repositories - python-musicpdaio.git/blobdiff - mpdaio/connection.py
Add unix socket (closes #1)
[python-musicpdaio.git] / mpdaio / connection.py
index d6f025b26fb63eedab8f989edb3a68362dc7d002..c47cd7ea3f52705d49ca50f6ba40b0442278f1e9 100644 (file)
@@ -12,6 +12,8 @@ from collections import OrderedDict
 from types import TracebackType
 from typing import Any, List, Optional, Tuple, Type
 
+from . import HELLO_PREFIX
+from .exceptions import MPDProtocolError
 
 try:  # Python 3.7
     base = contextlib.AbstractAsyncContextManager
@@ -37,7 +39,7 @@ class ConnectionPool(base):
                                        List["Connection"]] = OrderedDict()
         self._semaphore = asyncio.Semaphore(max_connections)
 
-    async def connect(self, server: Server, port: Port) -> "Connection":
+    async def connect(self, server: Server, port: Port, timeout: int) -> "Connection":
         host = (server, port)
 
         # enforce the connection limit, releasing connections notifies
@@ -49,6 +51,8 @@ class ConnectionPool(base):
         # find an un-used connection for this host
         connection = next(
             (conn for conn in connections if not conn.in_use), None)
+        #if connection:
+        #    log.debug('reusing %s', connection)
         if connection is None:
             # disconnect the least-recently-used un-used connection to make space
             # for a new connection. There will be at least one.
@@ -57,14 +61,27 @@ class ConnectionPool(base):
                     if not conn.in_use:
                         await conn.close()
                         break
-
-            reader, writer = await asyncio.open_connection(server, port)
+            if server[0] in ['/', '@']:
+                log.debug('about to connect unix socket %s', server)
+                reader, writer = await asyncio.wait_for(
+                        asyncio.open_unix_connection(path=server),
+                        timeout
+                        )
+            else:
+                log.debug('about to connect tcp socket %s:%s', *host)
+                reader, writer = await asyncio.wait_for(
+                        asyncio.open_connection(server, port),
+                        timeout
+                        )
+            #log.debug('Connected to %s:%s', host[0], host[1])
             connection = Connection(self, host, reader, writer)
+            await connection._hello()
             connections.append(connection)
 
         connection.in_use = True
         # move current host to the front as most-recently used
         self._connections.move_to_end(host, False)
+        log.debug('connection %s in use', connection)
 
         return connection
 
@@ -72,6 +89,7 @@ class ConnectionPool(base):
         """Close all connections"""
         connections = [c for cs in self._connections.values() for c in cs]
         self._connections = OrderedDict()
+        log.info('Closing all connections')
         for connection in connections:
             await connection.close()
 
@@ -126,7 +144,7 @@ class Connection(base):
 
     def __repr__(self):
         host = f"{self._host[0]}:{self._host[1]}"
-        return f"Connection<{host}>"
+        return f"Connection<{host}>#{id(self)}"
 
     @property
     def closed(self):
@@ -149,6 +167,17 @@ class Connection(base):
         except AttributeError:  # wait_closed is new in 3.7
             pass
 
+    async def _hello(self) -> None:
+        """Consume HELLO_PREFIX"""
+        self.in_use = True
+        data = await self.readuntil(b'\n')
+        rcv = data.decode('utf-8')
+        if not rcv.startswith(HELLO_PREFIX):
+            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
+        log.debug('consumed hello prefix: %r', rcv)
+        self.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
+        log.info('protocol version: %s', self.version)
+
     def __getattr__(self, name: str) -> Any:
         """All unknown attributes are delegated to the reader and writer"""
         if self._closed or not self.in_use: