from .utils import Range, escape
from .const import CONNECTION_MAX, CONNECTION_TIMEOUT
-from .const import ERROR_PREFIX, SUCCESS, NEXT
+from .const import HELLO_PREFIX, ERROR_PREFIX, SUCCESS, NEXT
log = logging.getLogger(__name__)
password: str | None = None):
#: Connection pool
self._pool = ConnectionPool(max_connections=CONNECTION_MAX)
+ #: connection timeout
+ self.mpd_timeout = CONNECTION_TIMEOUT
self._get_envvars()
#: Host used to make connections (:py:obj:`str`)
self.host = host or self.server_discovery[0]
self.pwd = password or self.server_discovery[2]
#: port used with the current connection (:py:obj:`int`, :py:obj:`str`)
self.port = port or self.server_discovery[1]
- #: connection timeout
- self.mpd_timeout = CONNECTION_TIMEOUT
log.info('Using %s:%s to connect', self.host, self.port)
def _get_envvars(self):
return retval
async def _init_connection(self):
- """Init connection if needed"""
+ """Init connection if needed
+
+ * Consumes the hello line and sets the protocol version
+ * Send password command if a password is provided
+ """
if not self.connection.version:
- # TODO: move hello here instead of connection?
- # Need to consume hello
- pass
+ await self._hello()
if self.password and not self.connection.auth:
# Need to send password
await self._write_command('password', [self.password])
await self._fetch_nothing()
self.connection.auth = True
+ async def _hello(self) -> None:
+ """Consume HELLO_PREFIX"""
+ data = await self.connection.readuntil(b'\n')
+ rcv = data.decode('utf-8')
+ if not rcv.startswith(HELLO_PREFIX):
+ raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
+ log.debug('consumed hello prefix: %r', rcv)
+ self.connection.version = rcv.split('\n')[0][len(HELLO_PREFIX):]
+
async def _write_line(self, line):
self.connection.write(f"{line!s}\n".encode())
await self.connection.drain()
async def _fetch_nothing(self):
line = await self._read_line()
if line is not None:
- raise ProtocolError(f"Got unexpected return value: '{line}'")
+ raise MPDProtocolError(f"Got unexpected return value: '{line}'")
async def _fetch_item(self):
pairs = [_ async for _ in self._read_pairs()]