# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2012-2024  kaliko <kaliko@azylum.org>
# SPDX-FileCopyrightText: 2008-2010  J. Alexander Treuman <jat@spatialrift.net>
# SPDX-License-Identifier: LGPL-3.0-or-later

import logging
import os

from .connection import ConnectionPool, Connection
from .exceptions import MPDConnectionError, MPDProtocolError, MPDCommandError
from .utils import Range, escape

from .const import CONNECTION_MAX, CONNECTION_TIMEOUT
from .const import HELLO_PREFIX, ERROR_PREFIX, SUCCESS, NEXT

#: Logger shared by every class of this module, named after the module
log = logging.getLogger(name=__name__)


class MPDClient:
    """:synopsis: Main class to instanciate building an MPD client.

    :param host: MPD server IP|FQDN to connect to
    :param port: MPD port to connect to
    :param password: MPD password

    **musicpdaio** tries to come with sane defaults, then running
    |mpdaio.MPDClient| with no explicit argument will try default values
    to connect to MPD. Cf. :ref:`reference` for more about
    :ref:`defaults<default_settings>`.

    The class is also exposed in mpdaio namespace.

    >>> import mpdaio
    >>> cli = mpdaio.MPDClient(host='example.org')
    >>> print(await cli.currentsong())
    >>> await cli.close()
    """

    def __init__(self, host: str | None = None,
                 port: str | int | None = None,
                 password: str | None = None):
        #: Connection pool
        self._pool = ConnectionPool(max_connections=CONNECTION_MAX)
        #: connection timeout
        self.mpd_timeout = CONNECTION_TIMEOUT
        self._get_envvars()
        #: Host used to make connections (:py:obj:`str`)
        self.host = host or self.server_discovery[0]
        #: password used to connect (:py:obj:`str`)
        self.pwd = password or self.server_discovery[2]
        #: port used with the current connection (:py:obj:`int`, :py:obj:`str`)
        self.port = port or self.server_discovery[1]
        log.info('Using %s:%s to connect', self.host, self.port)

    def _get_envvars(self):
        """
        Retrieve MPD env. var. to overrides default "localhost:6600"
        """
        # Set some defaults
        disco_host = 'localhost'
        disco_port = os.getenv('MPD_PORT', '6600')
        pwd = None
        _host = os.getenv('MPD_HOST', '')
        if _host:
            # If password is set: MPD_HOST=pass@host
            if '@' in _host:
                mpd_host_env = _host.split('@', 1)
                if mpd_host_env[0]:
                    # A password is actually set
                    log.debug(
                        'password detected in MPD_HOST, set client pwd attribute')
                    pwd = mpd_host_env[0]
                    if mpd_host_env[1]:
                        disco_host = mpd_host_env[1]
                        log.debug('host detected in MPD_HOST: %s', disco_host)
                elif mpd_host_env[1]:
                    # No password set but leading @ is an abstract socket
                    disco_host = '@'+mpd_host_env[1]
                    log.debug(
                        'host detected in MPD_HOST: %s (abstract socket)', disco_host)
            else:
                # MPD_HOST is a plain host
                disco_host = _host
                log.debug('host detected in MPD_HOST: %s', disco_host)
        else:
            # Is socket there
            xdg_runtime_dir = os.getenv('XDG_RUNTIME_DIR', '/run')
            rundir = os.path.join(xdg_runtime_dir, 'mpd/socket')
            if os.path.exists(rundir):
                disco_host = rundir
                log.debug(
                    'host detected in ${XDG_RUNTIME_DIR}/run: %s (unix socket)', disco_host)
        _mpd_timeout = os.getenv('MPD_TIMEOUT', '')
        if _mpd_timeout.isdigit():
            self.mpd_timeout = int(_mpd_timeout)
            log.debug('timeout detected in MPD_TIMEOUT: %d', self.mpd_timeout)
        else:  # Use CONNECTION_TIMEOUT as default even if MPD_TIMEOUT carries gargage
            self.mpd_timeout = CONNECTION_TIMEOUT
        self.server_discovery = (disco_host, disco_port, pwd)

    def __getattr__(self, attr):
        command = attr
        wrapper = CmdHandler(self._pool, self.host, self.port, self.pwd, self.mpd_timeout)
        if command not in wrapper._commands:
            command = command.replace("_", " ")
            if command not in wrapper._commands:
                raise AttributeError(
                    f"'CmdHandler' object has no attribute '{attr}'")
        return lambda *args: wrapper(command, args)

    @property
    def version(self) -> str:
        """MPD protocol version (returns an empty string if there is no opened connection)
        """
        host = (self.host, self.port)
        version = {_.version for _ in self.connections}
        if not version:
            log.warning('No connections yet in the connections pool for %s', host)
            return ''
        if len(version) > 1:
            log.warning('More than one version in the connections pool for %s', host)
        return version.pop()

    @property
    def connections(self) -> list[Connection]:
        """connections in the pool"""
        host = (self.host, self.port)
        return self._pool._connections.get(host, [])

    async def close(self) -> None:
        """:synopsis: Close connections in the pool"""
        await self._pool.close()


class CmdHandler:

    def __init__(self, pool, server, port, password, timeout):
        self._commands = {
            # Status Commands
            "clearerror":         self._fetch_nothing,
            "currentsong":        self._fetch_object,
            "idle":               self._fetch_list,
            "noidle":             self._fetch_nothing,
            "status":             self._fetch_object,
            "stats":              self._fetch_object,
            # Playback Option Commands
            "consume":            self._fetch_nothing,
            "crossfade":          self._fetch_nothing,
            "mixrampdb":          self._fetch_nothing,
            "mixrampdelay":       self._fetch_nothing,
            "random":             self._fetch_nothing,
            "repeat":             self._fetch_nothing,
            "setvol":             self._fetch_nothing,
            "getvol":             self._fetch_object,
            "single":             self._fetch_nothing,
            "replay_gain_mode":   self._fetch_nothing,
            "replay_gain_status": self._fetch_item,
            "volume":             self._fetch_nothing,
            # Playback Control Commands
            "next":               self._fetch_nothing,
            "pause":              self._fetch_nothing,
            "play":               self._fetch_nothing,
            "playid":             self._fetch_nothing,
            "previous":           self._fetch_nothing,
            "seek":               self._fetch_nothing,
            "seekid":             self._fetch_nothing,
            "seekcur":            self._fetch_nothing,
            "stop":               self._fetch_nothing,
            # Queue Commands
            "add":                self._fetch_nothing,
            "addid":              self._fetch_item,
            "clear":              self._fetch_nothing,
            "delete":             self._fetch_nothing,
            "deleteid":           self._fetch_nothing,
            "move":               self._fetch_nothing,
            "moveid":             self._fetch_nothing,
            "playlist":           self._fetch_playlist,
            "playlistfind":       self._fetch_songs,
            "playlistid":         self._fetch_songs,
            "playlistinfo":       self._fetch_songs,
            "playlistsearch":     self._fetch_songs,
            "plchanges":          self._fetch_songs,
            "plchangesposid":     self._fetch_changes,
            "prio":               self._fetch_nothing,
            "prioid":             self._fetch_nothing,
            "rangeid":            self._fetch_nothing,
            "shuffle":            self._fetch_nothing,
            "swap":               self._fetch_nothing,
            "swapid":             self._fetch_nothing,
            "addtagid":           self._fetch_nothing,
            "cleartagid":         self._fetch_nothing,
            # Stored Playlist Commands
            "listplaylist":       self._fetch_list,
            "listplaylistinfo":   self._fetch_songs,
            "listplaylists":      self._fetch_playlists,
            "load":               self._fetch_nothing,
            "playlistadd":        self._fetch_nothing,
            "playlistclear":      self._fetch_nothing,
            "playlistdelete":     self._fetch_nothing,
            "playlistlength":     self._fetch_object,
            "playlistmove":       self._fetch_nothing,
            "rename":             self._fetch_nothing,
            "rm":                 self._fetch_nothing,
            "save":               self._fetch_nothing,
            # Database Commands
            "albumart":           self._fetch_composite,
            "count":              self._fetch_object,
            "getfingerprint":     self._fetch_object,
            "find":               self._fetch_songs,
            "findadd":            self._fetch_nothing,
            "list":               self._fetch_list,
            "listall":            self._fetch_database,
            "listallinfo":        self._fetch_database,
            "listfiles":          self._fetch_database,
            "lsinfo":             self._fetch_database,
            "readcomments":       self._fetch_object,
            "readpicture":        self._fetch_composite,
            "search":             self._fetch_songs,
            "searchadd":          self._fetch_nothing,
            "searchaddpl":        self._fetch_nothing,
            "searchcount":        self._fetch_object,
            "update":             self._fetch_item,
            "rescan":             self._fetch_item,
            # Mounts and neighbors
            "mount":              self._fetch_nothing,
            "unmount":            self._fetch_nothing,
            "listmounts":         self._fetch_mounts,
            "listneighbors":      self._fetch_neighbors,
            # Sticker Commands
            "sticker get":        self._fetch_item,
            "sticker set":        self._fetch_nothing,
            "sticker delete":     self._fetch_nothing,
            "sticker list":       self._fetch_list,
            "sticker find":       self._fetch_songs,
            "stickernames":       self._fetch_list,
            # Connection Commands
            "close":              None,
            "kill":               None,
            "password":           self._fetch_nothing,
            "ping":               self._fetch_nothing,
            "binarylimit":        self._fetch_nothing,
            "tagtypes":           self._fetch_list,
            "tagtypes disable":   self._fetch_nothing,
            "tagtypes enable":    self._fetch_nothing,
            "tagtypes clear":     self._fetch_nothing,
            "tagtypes all":       self._fetch_nothing,
            # Partition Commands
            "partition":          self._fetch_nothing,
            "listpartitions":     self._fetch_list,
            "newpartition":       self._fetch_nothing,
            "delpartition":       self._fetch_nothing,
            "moveoutput":         self._fetch_nothing,
            # Audio Output Commands
            "disableoutput":      self._fetch_nothing,
            "enableoutput":       self._fetch_nothing,
            "toggleoutput":       self._fetch_nothing,
            "outputs":            self._fetch_outputs,
            "outputset":          self._fetch_nothing,
            # Reflection Commands
            "config":             self._fetch_object,
            "commands":           self._fetch_list,
            "notcommands":        self._fetch_list,
            "urlhandlers":        self._fetch_list,
            "decoders":           self._fetch_plugins,
            # Client to Client
            "subscribe":          self._fetch_nothing,
            "unsubscribe":        self._fetch_nothing,
            "channels":           self._fetch_list,
            "readmessages":       self._fetch_messages,
            "sendmessage":        self._fetch_nothing,
        }
        self.command = None
        self._command_list: list | None = None
        self.args = None
        self.pool = pool
        self.host = (server, port)
        self.password = password
        self.timeout = timeout
        #: current connection
        self.connection: [None, Connection] = None

    def __repr__(self):
        args = [str(_) for _ in self.args]
        args = ','.join(args or [])
        return f'{self.command}({args})'

    async def __call__(self, command: str, args: list | None):
        server, port = self.host
        self.command = command
        self.args = args or ''
        self.connection = await self.pool.connect(server, port, self.timeout)
        async with self.connection:
            await self._init_connection()
            retval = self._commands[command]
            await self._write_command(command, args)
            if callable(retval):
                return await retval()
            return retval

    async def _init_connection(self):
        """Init connection if needed

        * Consumes the hello line and sets the protocol version
        * Send password command if a password is provided
        """
        if not self.connection.version:
            await self._hello()
        if self.password and not self.connection.auth:
            # Need to send password
            await self._write_command('password', [self.password])
            await self._fetch_nothing()
            self.connection.auth = True

    async def _hello(self) -> None:
        """Consume HELLO_PREFIX"""
        data = await self.connection.readuntil(b'\n')
        rcv = data.decode('utf-8')
        if not rcv.startswith(HELLO_PREFIX):
            raise MPDProtocolError(f'Got invalid MPD hello: "{rcv}"')
        log.debug('consumed hello prefix: %r', rcv)
        self.connection.version = rcv.split('\n')[0][len(HELLO_PREFIX):]

    async def _write_line(self, line):
        self.connection.write(f"{line!s}\n".encode())
        await self.connection.drain()

    async def _write_command(self, command, args=None):
        if args is None:
            args = []
        parts = [command]
        for arg in args:
            if isinstance(arg, tuple):
                parts.append(f'{Range(arg)!s}')
            else:
                parts.append(f'"{escape(str(arg))}"')
        if '\n' in ' '.join(parts):
            raise MPDCommandError('new line found in the command!')
        #log.debug(' '.join(parts))
        await self._write_line(' '.join(parts))

    async def _read_binary(self, amount):
        chunk = bytearray()
        while amount > 0:
            result = await self.connection.read(amount)
            if len(result) == 0:
                await self.connection.close()
                raise ConnectionError(
                    "Connection lost while reading binary content")
            chunk.extend(result)
            amount -= len(result)
        return bytes(chunk)

    async def _read_line(self):
        line = await self.connection.readline()
        line = line.decode('utf-8')
        if not line.endswith('\n'):
            await self.connection.close()
            raise MPDConnectionError("Connection lost while reading line")
        line = line.rstrip('\n')
        if line.startswith(ERROR_PREFIX):
            error = line[len(ERROR_PREFIX):].strip()
            raise MPDCommandError(error)
        if self._command_list is not None:
            if line == NEXT:
                return None
            if line == SUCCESS:
                raise MPDProtocolError(f"Got unexpected '{SUCCESS}'")
        elif line == SUCCESS:
            return None
        return line

    async def _read_pair(self, separator):
        line = await self._read_line()
        if line is None:
            return None
        pair = line.split(separator, 1)
        if len(pair) < 2:
            raise MPDProtocolError(f"Could not parse pair: '{line}'")
        return pair

    async def _read_pairs(self, separator=": "):
        pair = await self._read_pair(separator)
        while pair:
            yield pair
            pair = await self._read_pair(separator)

    async def _read_list(self):
        seen = None
        async for key, value in self._read_pairs():
            if key != seen:
                if seen is not None:
                    raise MPDProtocolError(
                        f"Expected key '{seen}', got '{key}'")
                seen = key
            yield value

    async def _read_playlist(self):
        async for _, value in self._read_pairs(":"):
            yield value

    async def _read_objects(self, delimiters=None):
        obj = {}
        if delimiters is None:
            delimiters = []
        async for key, value in self._read_pairs():
            key = key.lower()
            if obj:
                if key in delimiters:
                    yield obj
                    obj = {}
                elif key in obj:
                    if not isinstance(obj[key], list):
                        obj[key] = [obj[key], value]
                    else:
                        obj[key].append(value)
                    continue
            obj[key] = value
        if obj:
            yield obj

    async def _read_command_list(self):
        try:
            for retval in self._command_list:
                yield retval()
        finally:
            self._command_list = None
        await self._fetch_nothing()

    async def _fetch_nothing(self):
        line = await self._read_line()
        if line is not None:
            raise MPDProtocolError(f"Got unexpected return value: '{line}'")

    async def _fetch_item(self):
        pairs = [_ async for _ in self._read_pairs()]
        if len(pairs) != 1:
            return None
        return pairs[0][1]

    async def _fetch_list(self):
        return [_ async for _ in self._read_list()]

    async def _fetch_playlist(self):
        return [_ async for _ in self._read_pairs(':')]

    async def _fetch_object(self):
        objs = [obj async for obj in self._read_objects()]
        if not objs:
            return {}
        return objs[0]

    async def _fetch_objects(self, delimiters):
        return [_ async for _ in self._read_objects(delimiters)]

    async def _fetch_changes(self):
        return await self._fetch_objects(["cpos"])

    async def _fetch_songs(self):
        return await self._fetch_objects(["file"])

    async def _fetch_playlists(self):
        return await self._fetch_objects(["playlist"])

    async def _fetch_database(self):
        return await self._fetch_objects(["file", "directory", "playlist"])

    async def _fetch_outputs(self):
        return await self._fetch_objects(["outputid"])

    async def _fetch_plugins(self):
        return await self._fetch_objects(["plugin"])

    async def _fetch_messages(self):
        return await self._fetch_objects(["channel"])

    async def _fetch_mounts(self):
        return await self._fetch_objects(["mount"])

    async def _fetch_neighbors(self):
        return await self._fetch_objects(["neighbor"])

    async def _fetch_composite(self):
        obj = {}
        async for key, value in self._read_pairs():
            key = key.lower()
            obj[key] = value
            if key == 'binary':
                break
        if not obj:
            # If the song file was recognized, but there is no picture, the
            # response is successful, but is otherwise empty.
            return obj
        amount = int(obj['binary'])
        try:
            obj['data'] = await self._read_binary(amount)
        except IOError as err:
            raise ConnectionError(
                f'Error reading binary content: {err}') from err
        data_bytes = len(obj['data'])
        if data_bytes != amount:  # can we ever get there?
            raise ConnectionError('Error reading binary content: '
                                  f'Expects {amount}B, got {data_bytes}')
        # Fetches trailing new line
        await self._read_line()
        #ALT: await self.connection.readuntil(b'\n')
        # Fetches SUCCESS code
        await self._read_line()
        #ALT: await self.connection.readuntil(b'OK\n')
        return obj

    async def _fetch_command_list(self):
        return [_ async for _ in self._read_command_list()]
