#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# The Qubes OS Project, http://www.qubes-os.org
#
# Copyright (C) 2013-2015  Marek Marczykowski-Górecki
#                                   <marmarek@invisiblethingslab.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
#
import os
import struct
import sys
import asyncio
import argparse
from qubesimgconverter import ICON_MAXSIZE, Image
import logging

import xcffib as xcb
import xcffib.xproto as xproto
from qubesadmin import Qubes
from qubesadmin.exc import QubesException

# Module-level logger used by the client handling code below
log = logging.getLogger('icon-receiver')

#: size of a single pixel (and of a single _NET_WM_ICON property item),
#: expressed in bytes and in bits respectively
BYTES_PER_PIXEL = 4
BITS_PER_PIXEL = 32
#: anti-DoS protection for header retrieval - this limit is used for both
#: "window header" (just window ID) and "icon header" (icon dimensions)
#: at most 2 32bit integers coded as decimal number, plus space between them
#: and EOL marker
#: NOTE: `2 << 32` is 2**33, which has the same number of decimal digits (10)
#: as 2**32, so the limit evaluates to 22.
#: NOTE(review): not referenced anywhere in the visible part of this file.
MAX_HEADER_LENGTH = 2 * len(str(2 << 32)) + 1 + 1

#: max size expected in _QUBES_VMNAME property; passed as the `long_length`
#: argument of GetProperty in refresh_windows_mapping()
#: NOTE(review): the X11 protocol counts long_length in 32-bit units, so
#: this allows up to 128 bytes - confirm that is intended
MAX_VMNAME_LENGTH = 32

#: how many windows to try to retrieve from the _NET_CLIENT_LIST property in
#: the first request; this is just an arbitrary number, trying to estimate
#: the maximum window count - if there are more windows, the remainder is
#: retrieved in a second call
INITIAL_CLIENT_LIST_RETRIEVAL_COUNT = 512

# Property format (not type!) used to store strings and other byte streams
X_FORMAT_8 = 8
X_FORMAT_STRING = X_FORMAT_8

# Property format (not type!) used to store 32bit numbers (like window IDs)
X_FORMAT_32 = 32
X_FORMAT_WINDOWID = X_FORMAT_32

#: default path of the UNIX socket on which handle_clients() listens
SOCKET_PATH = '/var/run/qubes/icon-receiver.sock'


class IconReceiver(object):
    """
    This class is responsible for handling windows icons updates sent from
    the VM. Each received icon is sanitized, tinted to appropriate VM color
    and then set to the local window.

    The protocol is simple, one direction stream (VM->dom0):
     - window ID as decimal number as the only string in the line
     - icon image in qubes.GetImageRGBA format (which is: icon dimensions in
     one line, then pixel stream, each pixel coded as 32-bit in RGBA order)
    """
    def __init__(self):
        """
        Connect to the local X server, resolve the atoms used by the other
        methods, subscribe to root window events and set up empty caches.
        """
        # Connect to local X server and get basic properties
        self.conn = xcb.connect()
        self.setup = self.conn.get_setup()
        self.root = self.setup.roots[0].root

        def intern_atom(name):
            # first argument (only_if_exists) is False; the reply is awaited
            # immediately, so atoms are resolved one after another
            cookie = self.conn.core.InternAtom(False, len(name), name)
            return cookie.reply().atom

        # Properties set by gui-daemon on each VM-originated window; we use
        # this to identify which window corresponds to requested VM window
        #: VM name of corresponding window
        self.atom_vmname = intern_atom("_QUBES_VMNAME")
        #: remote window ID
        self.atom_remote_id = intern_atom("_QUBES_VMWINDOWID")
        #: In case of reparenting window manager, we need to know which
        #: windows are real application, and which are only window frames
        self.atom_net_client_list = intern_atom("_NET_CLIENT_LIST")
        #: Property holding window icon - we set the received (sanitized and
        #: tinted) icon there
        self.atom_net_wm_icon = intern_atom("_NET_WM_ICON")

        # Select SubstructureNotify events on the root window; the
        # Create/Destroy notifications are consumed by handle_pending_events
        root_event_mask = [xproto.EventMask.SubstructureNotify]
        self.conn.core.ChangeWindowAttributesChecked(
            self.root, xproto.CW.EventMask, root_event_mask)
        self.conn.flush()

        #: Cache for remote->local window ID mapping
        self.remote2local_window_map = {}
        #: Cache for local->remote window ID mapping
        self.local2remote_window_map = {}

        #: cache of icons received for not (yet) created windows
        #: each element is a tuple of (winid, icon_data), where icon_data is
        #: already tinted and serialized for window property.
        #: at most 5 elements are stored
        self.icon_cache = []

        #: Qubes admin API entry point, used to look up VM label colors
        self.app = Qubes()

    def get_color(self, domain):
        # Load the VM properties - we need this only to get VM color
        self.app.domains.refresh_cache(force=True)
        try:
            vm = self.app.domains[domain]
        except KeyError:
            raise QubesException("VM '{}' doesn't exist in qubes.xml".format(
                domain))
        return vm.label.color

    @staticmethod
    def _unpack_int32_array(data):
        """
        Convert a GetProperty reply holding 32bit items into a list of
        numbers. Used for 32bit properties (window IDs, etc).

        A list is returned in every case (also for non-empty replies), so
        that callers can safely concatenate results of several calls.

        :param data: GetProperty reply; its ``value_len``, ``format`` and
            ``value`` attributes are read
        :return: list of unsigned 32bit numbers (empty for an empty reply)
        :raises TypeError: if a non-empty reply is not in format 32
        """
        if data.value_len == 0:
            return []
        if data.format != X_FORMAT_32:
            raise TypeError("Expected format 32")
        # struct.unpack() returns a tuple; convert it so the return type
        # matches the empty case above
        return list(
            struct.unpack("%dI" % data.value_len, data.value.buf()))

    def watch_window(self, w):
        """
        Select StructureNotify events on window *w*.

        This method itself neither flushes the connection nor checks the
        cookie returned by the request.

        :param w: local window ID
        :return: None
        """
        event_mask = [xproto.EventMask.StructureNotify]
        self.conn.core.ChangeWindowAttributesChecked(
            w, xproto.CW.EventMask, event_mask)

    def refresh_windows_mapping(self):
        """
        Enumerate windows and record them.

        This function updates self.local2remote_window_map and
        self.remote2local_window_map. Each time a window is added there,
        it is additionally watched for StructureNotify, to receive an event
        when it is destroyed.

        Windows are taken from the root window's _NET_CLIENT_LIST property
        or, when that list is empty, from the direct children of the root
        window. Windows already present in self.local2remote_window_map are
        not queried again.

        :return: None
        """
        # if embedding window manager is running, client windows are not
        # direct children of root window, so traverse such clients list ...
        client_list_reply = self.conn.core.GetProperty(
            False,  # do not delete property
            self.root,  # window
            self.atom_net_client_list,  # property
            xproto.Atom.WINDOW,  # type
            0,  # offset
            INITIAL_CLIENT_LIST_RETRIEVAL_COUNT  # length
        ).reply()
        # always work on a list, regardless of the sequence type returned
        # by the helper, so the two parts can be safely joined
        client_list = list(self._unpack_int32_array(client_list_reply))
        if client_list_reply.bytes_after:
            # offset and length are counted in 32bit units, while
            # bytes_after is counted in bytes - convert it, rounding up
            remaining_items = (client_list_reply.bytes_after + 3) // 4
            remaining_reply = self.conn.core.GetProperty(
                False,  # do not delete property
                self.root,  # window
                self.atom_net_client_list,  # property
                xproto.Atom.WINDOW,  # type
                client_list_reply.value_len,  # offset
                remaining_items  # length
            ).reply()
            client_list.extend(self._unpack_int32_array(remaining_reply))
        if not client_list:
            # ... otherwise just look at root window children
            client_list = self.conn.core.QueryTree(self.root).reply().children

        # Now iterate over all application windows. For performance reasons
        # (according to XCB manual), do this in two runs:
        # First issue GetProperty commands (recording "cookies") ...
        name_queries = {}
        remote_id_queries = {}
        for w in client_list:
            if w in self.local2remote_window_map:
                # already cached
                continue

            name_queries[w] = self.conn.core.GetProperty(
                False,  # delete
                w,  # window
                self.atom_vmname,  # property
                xproto.Atom.STRING,  # type
                0,  # long_offset
                MAX_VMNAME_LENGTH  # long_length
            )
            remote_id_queries[w] = self.conn.core.GetProperty(
                False,  # delete
                w,  # window
                self.atom_remote_id,  # property
                xproto.Atom.WINDOW,  # type
                0,  # long_offset
                1  # long_length
            )

        # ... then retrieve results
        watch_requested = False
        for w, name_cookie in name_queries.items():
            try:
                vmname = name_cookie.reply()
                remote_id_reply = remote_id_queries[w].reply()
            except xproto.WindowError:
                # the request failed for this window; skip it
                continue
            if vmname.format != X_FORMAT_STRING:
                # no usable _QUBES_VMNAME property on this window
                continue
            # if _QUBES_VMWINDOWID is set, store it in the map,
            # otherwise simply ignore the window - most likely it was
            # just created and doesn't have that property yet
            if remote_id_reply.format != X_FORMAT_WINDOWID or \
                    not remote_id_reply.value_len:
                continue
            domain = vmname.value.buf().decode()
            win_remote_id = (
                domain, self._unpack_int32_array(remote_id_reply)[0])
            self.remote2local_window_map[win_remote_id] = w
            self.local2remote_window_map[w] = win_remote_id
            self.watch_window(w)
            watch_requested = True

        if watch_requested:
            # make sure the event selection requests are actually sent now,
            # not only whenever the next request happens to flush them
            self.conn.flush()
    def search_for_window(self, remote_id):
        """
        Search for local window matching given remote window ID. Raise
        KeyError if none exists
        :param remote_id: remote window ID
        :return: local window ID
        """
        # first handle events - remove outdated IDs
        self.handle_pending_events()
        if remote_id not in self.remote2local_window_map:
            self.refresh_windows_mapping()
        # may raise KeyError
        return self.remote2local_window_map[remote_id]

    async def handle_events(self):
        """
        Register the X connection's file descriptor with the event loop, so
        that handle_pending_events() is called whenever it becomes readable.

        The coroutine returns right after the registration.
        """
        loop = asyncio.get_event_loop()
        x_connection_fd = self.conn.get_file_descriptor()
        loop.add_reader(x_connection_fd, self.handle_pending_events)

    def handle_pending_events(self):
        """
        Process all X11 events currently returned by poll_for_event().

        - DestroyNotifyEvent: drop the event window from both window maps
        - CreateNotifyEvent: try to apply every cached icon; entries that
          could be applied are removed from the cache
        - any other event is ignored

        :return: None
        """
        while True:
            event = self.conn.poll_for_event()
            if event is None:
                # nothing more is pending
                return

            if isinstance(event, xproto.DestroyNotifyEvent):
                remote_id = self.local2remote_window_map.pop(
                    event.window, None)
                if remote_id is not None:
                    # a missing reverse entry is not an error
                    self.remote2local_window_map.pop(remote_id, None)
                continue

            if not isinstance(event, xproto.CreateNotifyEvent):
                continue

            # iterate over a snapshot, as the cache is modified in the loop
            for cached_entry in list(self.icon_cache):
                remote_winid, icon_property_data = cached_entry
                try:
                    local_winid = self.search_for_window(remote_winid)
                    self.set_icon_for_window(local_winid, icon_property_data)
                    self.icon_cache.remove(cached_entry)
                except (KeyError, ValueError):
                    # KeyError: no such window (yet); ValueError: the entry
                    # is no longer in the cache
                    pass

    @staticmethod
    def _convert_rgba_to_argb(rgba_image):
        """
        Re-order pixel components from RGBA to ARGB.

        qubes.GetImageRGBA format uses RGBA byte order, while X11 (in
        _NET_WM_ICON property) uses ARGB packed into a 32bit number.

        :param rgba_image: bytes with pixels, 4 bytes (R, G, B, A) per pixel
        :return: bytes with the same pixels, each packed as a native-order
            32bit ARGB number
        """
        pixel_count = len(rgba_image) // BYTES_PER_PIXEL
        # read every pixel as big-endian number, i.e. 0xRRGGBBAA
        rgba_pixels = struct.unpack(">%dI" % pixel_count, rgba_image)
        argb_pixels = []
        for pixel in rgba_pixels:
            alpha = pixel & 0xff
            rgb = pixel >> 8
            # move the lowest byte (alpha) to be the highest
            argb_pixels.append((alpha << 24) | rgb)
        return struct.pack("%dI" % pixel_count, *argb_pixels)

    async def retrieve_icon_for_window(self, reader, color):
        """
        Read one icon from *reader*, tint it and serialize it in the
        _NET_WM_ICON property format.

        :param reader: stream to read the icon from
        :param color: color passed to the icon's tint() method
        :return: bytes: width and height as two 32bit numbers, followed by
            the ARGB pixel data
        """
        # Exceptions are intentionally not caught here.
        # Image.get_from_stream_async receives UNTRUSTED data from the given
        # stream, sanitizes it and stores it in an Image() object
        # (ICON_MAXSIZE is passed twice - presumably the maximum width and
        # height; verify against qubesimgconverter)
        received_icon = await Image.get_from_stream_async(
            reader, ICON_MAXSIZE, ICON_MAXSIZE)
        # now we can tint the icon to the VM color
        tinted_icon = received_icon.tint(color)
        # convert RGBA (Image.data) -> ARGB (X11)
        argb_pixels = self._convert_rgba_to_argb(tinted_icon.data)
        # prepare icon header according to X11 _NET_WM_ICON format:
        # "This is an array of 32bit packed CARDINAL ARGB with high byte
        # being A, low byte being B. The first two cardinals are width, height.
        # Data is in rows, left to right and top to bottom."
        # http://standards.freedesktop.org/wm-spec/1.4/ar01s05.html
        size_header = struct.pack(
            "II", tinted_icon.width, tinted_icon.height)
        # the actual icon goes right after the header
        return size_header + argb_pixels

    def set_icon_for_window(self, window, icon_property_data):
        """
        Replace the _NET_WM_ICON property of *window* and flush the request.

        :param window: local window ID
        :param icon_property_data: serialized icon (bytes), as returned by
            retrieve_icon_for_window()
        :return: None
        """
        # the property length is given in 32bit items, not in bytes
        item_count = len(icon_property_data) // BYTES_PER_PIXEL
        core = self.conn.core
        core.ChangeProperty(
            xproto.PropMode.Replace,
            window,
            self.atom_net_wm_icon,  # property
            xproto.Atom.CARDINAL,  # type
            BITS_PER_PIXEL,  # format
            item_count,
            icon_property_data)
        self.conn.flush()

    def cache_icon(self, remote_winid, icon_property_data):
        """
        Cache icon
        :return: None
        """
        cache_dict = dict(self.icon_cache)
        if remote_winid in cache_dict:
            self.icon_cache.remove((remote_winid, cache_dict[remote_winid]))
        self.icon_cache.insert(0, (remote_winid, icon_property_data))
        self.icon_cache = self.icon_cache[:5]

    async def handle_clients(self, socket_path=SOCKET_PATH):
        """
        Listen on a UNIX socket and serve clients until cancelled.

        Anything already present at *socket_path* is unlinked first; each
        accepted connection is then passed to handle_client().

        :param socket_path: filesystem path of the listening socket
        """
        # EAFP instead of exists()+unlink(): no window between the check
        # and the removal, and a dangling symlink (for which
        # os.path.exists() returns False) is removed too
        try:
            os.unlink(socket_path)
        except FileNotFoundError:
            pass
        server = await asyncio.start_unix_server(
            self.handle_client, socket_path)
        await server.serve_forever()

    async def handle_client(self, reader, writer):
        """
        Serve a single client connection.

        The stream starts with a NUL-terminated header of space-separated
        fields: the service name (optionally followed by ``+`` and an empty
        argument) and the source domain name; further fields are ignored.
        After the header, the stream is a sequence of: a line with the
        window ID, followed by the icon (see retrieve_icon_for_window()).
        Each icon is set on the matching local window, or cached when
        no such window is known.

        The writer is always closed at the end, also on errors.

        :param reader: stream reader of the connection
        :param writer: stream writer of the connection
        :raises ValueError: on a malformed header, an unexpected service
            name or argument, or an invalid window ID
        """
        try:
            # Parse header from qrexec
            header = await reader.readuntil(b'\0')
            # readuntil() returns the separator too; strip it, otherwise
            # it ends up in the last field (the domain name, for a header
            # with just two fields)
            header_parts = header.rstrip(b'\0').decode('ascii').split(' ')
            # explicit checks instead of assert, which is skipped with -O
            if len(header_parts) < 2:
                raise ValueError(
                    'Malformed header: {!r}'.format(header_parts))

            service_name, _, arg = header_parts[0].partition('+')
            if arg != '':
                raise ValueError(
                    'Unexpected service argument: {!r}'.format(arg))
            if service_name != 'qubes.WindowIconUpdater':
                raise ValueError(
                    'Unexpected service name: {!r}'.format(service_name))

            domain = header_parts[1]
            color = self.get_color(domain)

            log.info('connected: %s', domain)

            while True:
                untrusted_w = await reader.readline()
                if untrusted_w == b'':
                    # EOF - the client is done
                    break
                if len(untrusted_w) > 32:
                    raise ValueError("WindowID too long")
                # int() raises ValueError on non-numeric data
                remote_winid = (domain, int(untrusted_w))
                icon_property_data = await self.retrieve_icon_for_window(
                    reader, color)
                try:
                    local_winid = self.search_for_window(remote_winid)
                    self.set_icon_for_window(local_winid, icon_property_data)
                except KeyError:
                    # window not known (yet) - keep the icon for later
                    self.cache_icon(remote_winid, icon_property_data)

            log.info('disconnected: %s', domain)
        finally:
            writer.close()
            await writer.wait_closed()


# Command line interface; parsed in main()
parser = argparse.ArgumentParser()

# -f/--force: skip the environment check done in main() before starting
parser.add_argument(
    '-f', '--force', action='store_true',
    help='run even if not in GuiVM')


def main():
    """
    Entry point: parse arguments, check the environment, configure logging
    and run the icon receiver until the event loop stops.

    Unless --force is given, the function returns early (after printing a
    message to stderr) when neither of the marker files checked below
    exists.
    """
    args = parser.parse_args()

    if not args.force:
        if not (os.path.exists('/var/run/qubes-service/guivm-gui-agent') or
                os.path.exists('/etc/qubes-release')):
            print('Not in GuiVM or dom0, exiting '
                  '(run with --force to override)',
                  file=sys.stderr)
            return

    logging.basicConfig(
        stream=sys.stderr, level=logging.INFO,
        format='%(asctime)s %(name)s: %(message)s')

    rcvd = IconReceiver()

    def handle_exception(loop, context):
        # Losing the X connection is fatal - log it and exit
        e = context.get('exception')
        if isinstance(e, xcb.ConnectionException):
            log.error("Connection error: %s", e)
            sys.exit(1)
        # Do not silently drop everything else (for example errors raised
        # while serving a client); keep asyncio's default reporting
        loop.default_exception_handler(context)

    async def run_all():
        # gather() is called from within the running loop
        await asyncio.gather(
            rcvd.handle_events(),
            rcvd.handle_clients(),
        )

    # Create the loop explicitly instead of relying on
    # asyncio.get_event_loop() creating one implicitly
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.set_exception_handler(handle_exception)
    loop.run_until_complete(run_all())


if __name__ == '__main__':
    main()
