# -*- coding: utf-8 -*-
#
# network.py - I/O with WeeChat/relay
#
# Copyright (C) 2011-2022 Sébastien Helleu <flashcode@flashtux.org>
#
# This file is part of QWeeChat, a Qt remote GUI for WeeChat.
#
# QWeeChat is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# QWeeChat is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with QWeeChat.  If not, see <http://www.gnu.org/licenses/>.
#

"""I/O with WeeChat/relay."""

import hashlib
import secrets
import struct

from gi.repository import GObject, GLib, Gio

from weegtk import config

from sshtunnel import SSHTunnelForwarder

# hash algorithms supported on our side, in the order they are advertised
# (the one actually used is negotiated with the remote WeeChat)
_HASH_ALGOS_LIST = [
    'plain', 'sha256', 'sha512',
    'pbkdf2+sha256', 'pbkdf2+sha512',
]
# colon-separated form of the list above, embedded in the handshake command
_HASH_ALGOS = ':'.join(_HASH_ALGOS_LIST)

# handshake command, sent to the remote WeeChat before "init"
_PROTO_HANDSHAKE = f'(handshake) handshake password_hash_algo={_HASH_ALGOS}\n'

# "init" command carrying the password in plain text
_PROTO_INIT_PWD = 'init password=%(password)s%(totp)s\n'  # nosec

# "init" command carrying a hashed password
_PROTO_INIT_HASH = (
    'init password_hash='
    '%(algo)s:%(salt)s%(iter)s:%(hash)s%(totp)s\n'
)

# commands used to synchronize with WeeChat; they are joined with newlines
# and "%(lines)d" is substituted before being sent
_PROTO_SYNC_CMDS = [
    # list of buffers
    ('(listbuffers) hdata buffer:gui_buffers(*) number,full_name,short_name,'
     'type,nicklist,title,local_variables'),
    # last lines of every buffer
    ('(listlines) hdata buffer:gui_buffers(*)/own_lines/last_line(-%(lines)d)/'
     'data date,displayed,prefix,message'),
    # nicklist of every buffer
    '(nicklist) nicklist',
    # turn synchronization on
    'sync',
]

# identifiers of the connection states
STATUS_DISCONNECTED = 'disconnected'
STATUS_CONNECTING = 'connecting'
STATUS_AUTHENTICATING = 'authenticating'
STATUS_CONNECTED = 'connected'

# one row per state: (status, label, color, icon file name)
_STATUS_ROWS = (
    (STATUS_DISCONNECTED, 'Disconnected', '#aa0000', 'dialog-close.png'),
    (STATUS_CONNECTING, 'Connecting…', '#dd5f00', 'dialog-warning.png'),
    (STATUS_AUTHENTICATING, 'Authenticating…', '#007fff',
     'dialog-password.png'),
    (STATUS_CONNECTED, 'Connected', 'green', 'dialog-ok-apply.png'),
)

# display attributes of each state, indexed by status identifier
NETWORK_STATUS = {
    status: {'label': label, 'color': color, 'icon': icon}
    for status, label, color, icon in _STATUS_ROWS
}


class Network(GObject.GObject):
    """I/O with WeeChat/relay.

    The object connects to a relay, authenticates, and publishes what
    happens through two GObject signals declared below.
    """
    __gsignals__ = {
        # emitted by set_status() with the new status and None
        "status_changed" : (
            GObject.SignalFlags.RUN_FIRST,
            None,
            (str, str)
        ),
        # emitted by _socket_read() with the bytes of one received message
        "message_from_weechat" : (
            GObject.SignalFlags.RUN_FIRST,
            None,
            (GLib.Bytes,)
        )
    }

    def __init__(self, *args):
        """Create a network object in the disconnected state.

        Positional arguments are passed unchanged to the GObject
        constructor.
        """
        super().__init__(*args)
        self._init_connection()
        # history of debug messages as (args, kwargs) pairs, given to the
        # debug dialog when it is opened
        self.debug_lines = []
        self.debug_dialog = None
        self._lines = config.CONFIG_DEFAULT_RELAY_LINES
        # cancellable handed to every asynchronous read on the socket
        self.cancel_network_reads = Gio.Cancellable()
        # bytes received from the relay that are not a full message yet
        self._buffer = bytearray()
        self._socketclient = Gio.SocketClient.new()
        self._socket = None
        # Defined up front so that code reading them before a connection
        # exists, or after a connection made without SSH, does not fail
        # with AttributeError.
        self.input = None
        self.ssh_tunnel = None

        # TODO: nothing in this file calls _socket_disconnected() yet; it
        # still has to be hooked to a "connection lost" event.

    def _init_connection(self):
        """Reset all per-connection attributes to their idle values.

        The status is assigned directly: no ``status_changed`` signal is
        emitted from here.
        """
        # relay address and credentials
        self._hostname = self._port = self._ssl = None
        self._password = self._totp = None
        # handshake state
        self._handshake_received = False
        self._handshake_timer = None
        # password hashing parameters
        self._pwd_hash_algo = self._server_nonce = None
        self._pwd_hash_iter = 0
        self.status = STATUS_DISCONNECTED

    def set_status(self, status):
        """Store the new status and emit the ``status_changed`` signal."""
        self.status = status
        # second signal argument is always None here
        signal_args = (status, None)
        self.emit("status_changed", *signal_args)

    def pbkdf2(self, hash_name, salt):
        """Hash the stored password with PBKDF2-HMAC.

        The password is encoded as UTF-8 and the iteration count is taken
        from ``self._pwd_hash_iter``; the derived key is returned as a
        hexadecimal string.
        """
        secret = self._password.encode('utf-8')
        derived_key = hashlib.pbkdf2_hmac(
            hash_name, secret, salt, self._pwd_hash_iter)
        return derived_key.hex()

    def _build_init_command(self):
        """Return the "init" command matching the selected hash algorithm.

        With the "plain" algorithm the password is sent as is; otherwise
        it is hashed with a salt made of the server nonce followed by a
        new 16-byte client nonce.  None is returned when the algorithm is
        not one of those handled here.
        """
        totp_option = ''
        if self._totp:
            totp_option = f',totp={self._totp}'

        algo = self._pwd_hash_algo
        if algo == 'plain':  # nosec
            return _PROTO_INIT_PWD % {
                'password': self._password,
                'totp': totp_option,
            }

        salt = self._server_nonce + secrets.token_bytes(16)
        if algo in ('pbkdf2+sha512', 'pbkdf2+sha256'):  # nosec
            # name after the "+" is the digest used by PBKDF2
            digest = self.pbkdf2(algo.partition('+')[2], salt)
            iter_field = f':{self._pwd_hash_iter}'
        elif algo in ('sha512', 'sha256'):  # nosec
            salted_password = salt + self._password.encode('utf-8')
            digest = getattr(hashlib, algo)(salted_password).hexdigest()
            iter_field = ''
        else:
            return None

        return _PROTO_INIT_HASH % {
            'algo': algo,
            'salt': bytearray(salt).hex(),
            'iter': iter_field,
            'hash': digest,
            'totp': totp_option,
        }

    def _build_sync_command(self):
        """Return the sync commands as one newline-terminated string.

        The number of lines to fetch (``self._lines``) is substituted into
        the commands.
        """
        substitutions = {'lines': self._lines}
        template = '\n'.join(_PROTO_SYNC_CMDS)
        return (template + '\n') % substitutions

    def handshake_timer_expired(self):
        """Timer callback: no handshake response arrived in time.

        If the status is still "authenticating", fall back to the "plain"
        algorithm: send the init command, the sync commands, and switch to
        the "connected" status.

        Always returns False so that GLib removes the timeout source (the
        callback runs at most once).
        """
        # The source no longer exists once this callback has returned
        # False, so the stored id must not be used again.
        self._handshake_timer = None
        if self.status == STATUS_AUTHENTICATING:
            self._pwd_hash_algo = 'plain'  # nosec
            self.send_to_weechat(self._build_init_command())
            self.sync_weechat()
            self.set_status(STATUS_CONNECTED)
        return False

    def _socket_connected(self):
        """Start authentication on a freshly connected socket.

        Switches to the "authenticating" status, sends the handshake
        command and arms a timer calling handshake_timer_expired().
        """
        self.set_status(STATUS_AUTHENTICATING)
        self.send_to_weechat(_PROTO_HANDSHAKE)
        timeout_ms = 2000
        self._handshake_timer = GLib.timeout_add(
            timeout_ms, self.handshake_timer_expired)

    def _socket_read(self, source_object, res, *user_data):
        """Callback of an asynchronous read: data available on socket.

        Received bytes are appended to the internal buffer.  The first
        4 bytes of a message are read as a big-endian signed integer
        giving the length of the whole message (these 4 bytes included);
        every complete message found in the buffer is emitted with the
        ``message_from_weechat`` signal, then another read is scheduled.

        Reading stops (no new read is scheduled) when the read failed,
        when the peer closed the connection, or when the socket is no
        longer connected after a message was dispatched.
        """
        try:
            gbytes = self.input.read_bytes_finish(res)
        except GLib.Error as err:
            self.handle_network_error(err)
            return

        data = gbytes.get_data()
        if not data:
            # A successful read of zero bytes means end of stream:
            # scheduling another read would complete immediately, forever.
            print("Server has closed the connection.")
            self._buffer.clear()
            if self.status != STATUS_DISCONNECTED:
                self.set_status(STATUS_DISCONNECTED)
            return

        self._buffer.extend(data)
        header_size = 4
        while len(self._buffer) >= header_size:
            length = struct.unpack_from('>i', self._buffer)[0]
            if length < header_size:
                # A message cannot be shorter than its own length field;
                # consuming "length" bytes would never empty the buffer.
                print("Invalid message length received, data discarded.")
                self._buffer.clear()
                break
            if len(self._buffer) < length:
                # partial message, just wait for end of message
                break
            # take the message out of the buffer, keep what follows it
            message = bytes(self._buffer[:length])
            del self._buffer[:length]
            self.emit("message_from_weechat", GLib.Bytes(message))
            if not self.is_connected():
                # a handler closed the connection: remaining bytes must
                # not be parsed as part of a later connection
                self._buffer.clear()
                return

        self.input.read_bytes_async(
            4096, 0, self.cancel_network_reads, self._socket_read)

    def _socket_disconnected(self):
        """Slot: socket disconnected.

        Removes the pending handshake timer if there is one, resets the
        connection attributes and switches to the "disconnected" status.
        """
        timer_id = self._handshake_timer
        if timer_id:
            # GLib.timeout_add() returns an integer source id (there is
            # no timer object to stop).  The id may belong to a timer
            # that has already fired, so only remove a source that still
            # exists.
            source = GLib.MainContext.default().find_source_by_id(timer_id)
            if source is not None and not source.is_destroyed():
                GLib.source_remove(timer_id)
        self._init_connection()
        self.set_status(STATUS_DISCONNECTED)

    def is_connected(self):
        """Tell whether a socket exists and reports itself as connected."""
        connection = self._socket
        return False if connection is None else connection.is_connected()

    def is_ssl(self):
        """Return the SSL setting given to connect_weechat().

        The stored value is returned as is (None when no connection has
        been requested).
        """
        ssl_setting = self._ssl
        return ssl_setting

    def connect_weechat_ssh(self, ssh_host, ssh_port, ssh_username, ssh_key,
                            relay_host, relay_port, relay_pw):
        """Connect to WeeChat through an SSH tunnel.

        A tunnel is opened to (ssh_host, ssh_port) with the given user
        name and private key, forwarding a local port to
        (relay_host, relay_port) on the remote side; the relay connection
        is then made to that local port, without SSL and without TOTP.

        Raises:
            ValueError: if ssh_port or relay_port is not a valid integer.
        """
        # Do not leak a tunnel left over from a previous call.
        previous_tunnel = getattr(self, 'ssh_tunnel', None)
        if previous_tunnel is not None:
            previous_tunnel.stop()

        self.ssh_tunnel = SSHTunnelForwarder(
            (ssh_host, int(ssh_port)),
            ssh_username=ssh_username,
            ssh_pkey=ssh_key,
            # Listen on the loopback interface only (port 0 = any free
            # port): by default the local end is bound to all interfaces,
            # which would expose the relay to other hosts.
            local_bind_address=('127.0.0.1', 0),
            remote_bind_address=(relay_host, int(relay_port)))

        self.ssh_tunnel.start()

        # An empty "lines" value makes connect_weechat() use the default
        # number of lines.
        self.connect_weechat("localhost", self.ssh_tunnel.local_bind_port,
                             False, relay_pw, None, "")

    def connect_weechat(self, hostname, port, ssl, password, totp, lines):
        """Connect to WeeChat.

        The connection is asynchronous: this method stores the options,
        starts the connection attempt and switches to the "connecting"
        status; _connected_func() is called with the result.

        Args:
            hostname: host name or address of the relay.
            port: relay port; 0 is used if it is not a valid integer.
            ssl: SSL setting; stored but not applied yet (see TODO below).
            password: relay password.
            totp: time-based one-time password, or a false value for none.
            lines: number of lines to fetch per buffer; the configured
                default is used if it is not a valid integer.
        """
        self._hostname = hostname
        try:
            self._port = int(port)
        except (TypeError, ValueError):
            # TypeError: value that int() does not accept at all (None)
            self._port = 0
        self._ssl = ssl
        self._password = password
        self._totp = totp
        try:
            self._lines = int(lines)
        except (TypeError, ValueError):
            self._lines = config.CONFIG_DEFAULT_RELAY_LINES

        # Bytes left over from a previous connection must not be parsed
        # together with data of the new one.
        self._buffer.clear()

        # a cancelled cancellable would make every new read fail at once
        if self.cancel_network_reads.is_cancelled():
            self.cancel_network_reads.reset()

        # TODO handle SSL
        self._socketclient.connect_async(
            Gio.NetworkAddress.new(self._hostname, self._port),
            None,
            self._connected_func, None)
        self.set_status(STATUS_CONNECTING)

    def _connected_func(self, source_object, res, *user_data):
        """Finish the connection attempt started by connect_weechat().

        On failure the error is printed and the status becomes
        "disconnected".  On success the socket is stored, authentication
        is started and the first asynchronous read is scheduled.
        """
        try:
            connection = self._socketclient.connect_finish(res)
        except GLib.Error as err:
            print("Connection failed:\n{}".format(err.message))
            self.set_status(STATUS_DISCONNECTED)
            return

        self._socket = connection
        print("Connected")
        # NOTE(review): this "connected" status is replaced right away by
        # "authenticating" in _socket_connected() - confirm that listeners
        # expect this intermediate signal.
        self.set_status(STATUS_CONNECTED)
        self._socket_connected()
        self.input = self._socket.get_input_stream()
        self.input.read_bytes_async(
            4096, 0, self.cancel_network_reads, self._socket_read)

    def disconnect_weechat(self):
        """Disconnect from WeeChat.

        If the socket is connected, the "quit" command is sent, the socket
        is closed, pending reads are cancelled and the status becomes
        "disconnected".  An SSH tunnel, if any, is stopped in all cases.
        Calling this method when nothing is connected is safe.
        """
        # is_connected() also covers the case where no socket exists yet
        was_connected = self.is_connected()
        if was_connected:
            self.send_to_weechat('quit\n')
            self._socket.set_graceful_disconnect(True)
            self._socket.close()
            self._socket = None
            self.cancel_network_reads.cancel()

        # The tunnel only exists for connections made through SSH, and the
        # attribute itself may not have been created.
        tunnel = getattr(self, 'ssh_tunnel', None)
        if tunnel is not None:
            tunnel.stop()
            self.ssh_tunnel = None

        if was_connected:
            self.set_status(STATUS_DISCONNECTED)

    def handle_network_error(self, err):
        """Print a message for an expected network error.

        Args:
            err: the GLib error to examine.

        Raises:
            The error itself, when it is not one of the expected ones.
        """
        expected_errors = (
            (Gio.io_error_quark(), Gio.IOErrorEnum.CANCELLED,
             "Connection has been canceled by user."),
            (Gio.tls_error_quark(), Gio.TlsError.EOF,
             "Server has closed the connection."),
            (Gio.io_error_quark(), Gio.IOErrorEnum.BROKEN_PIPE,
             "Broken pipe, connection lost."),
            (Gio.io_error_quark(), Gio.IOErrorEnum.TIMED_OUT,
             "Connection timed out."),
        )
        for domain, code, text in expected_errors:
            if err.matches(domain, code):
                print(text)
                return
        # Raise the error explicitly: a bare "raise" only works while the
        # caller is still inside its "except" block.
        raise err

    def send_to_weechat(self, message):
        """Send a message to WeeChat.

        The message (a string) is encoded as UTF-8 and written entirely to
        the socket.  Nothing is sent when there is no socket.  Write
        errors are passed to handle_network_error().
        """
        if self._socket is None:
            print("Not connected, message not sent.")
            return
        output = self._socket.get_output_stream()
        try:
            # write_all() loops until every byte is written; a single
            # write() may stop after only a part of the data
            output.write_all(message.encode("utf-8"), None)
        except GLib.Error as err:
            self.handle_network_error(err)

    def init_with_handshake(self, response):
        """Initialize with WeeChat using the handshake response.

        Args:
            response: mapping with the keys "password_hash_algo",
                "password_hash_iterations" and "nonce" (hexadecimal
                string).

        When an algorithm was negotiated and the init command can be
        built, the init and sync commands are sent and the status becomes
        "connected"; otherwise the connection is closed.  A response
        received while the status is not "authenticating" is ignored.
        """
        if self.status != STATUS_AUTHENTICATING:
            # Late response: handshake_timer_expired() has already
            # authenticated with the plain password (or the connection is
            # gone); sending init and sync a second time would be wrong.
            return
        self._handshake_received = True
        self._pwd_hash_algo = response['password_hash_algo']
        self._pwd_hash_iter = int(response['password_hash_iterations'])
        self._server_nonce = bytearray.fromhex(response['nonce'])
        if self._pwd_hash_algo:
            cmd = self._build_init_command()
            if cmd:
                self.send_to_weechat(cmd)
                self.sync_weechat()
                self.set_status(STATUS_CONNECTED)
                return
        # failed to initialize: disconnect
        self.disconnect_weechat()

    def desync_weechat(self):
        """Send the "desync" command to WeeChat."""
        command = 'desync\n'
        self.send_to_weechat(command)

    def sync_weechat(self):
        """Send the commands built by _build_sync_command() to WeeChat."""
        commands = self._build_sync_command()
        self.send_to_weechat(commands)

    def status_label(self, status):
        """Return the label of a status ('' if the status is unknown)."""
        attributes = NETWORK_STATUS.get(status, {})
        return attributes.get('label', '')

    def status_color(self, status):
        """Return the color of a status ('black' if it is unknown)."""
        attributes = NETWORK_STATUS.get(status, {})
        return attributes.get('color', 'black')

    def status_icon(self, status):
        """Return the icon name of a status ('' if it is unknown)."""
        attributes = NETWORK_STATUS.get(status, {})
        return attributes.get('icon', '')

    def get_options(self):
        """Return the connection options as a dict.

        "ssl" is given as 'on' or 'off' and "lines" as a string; the other
        values are returned as stored.
        """
        ssl_state = 'off'
        if self._ssl:
            ssl_state = 'on'
        options = {
            'hostname': self._hostname,
            'port': self._port,
            'ssl': ssl_state,
            'password': self._password,
            'lines': str(self._lines),
        }
        return options

    def debug_print(self, *args, **kwargs):
        """Record a debug message and show it if the dialog is open.

        The arguments are kept as an (args, kwargs) pair in
        ``self.debug_lines`` and passed unchanged to the dialog's chat
        display.
        """
        record = (args, kwargs)
        self.debug_lines.append(record)
        dialog = self.debug_dialog
        if dialog:
            dialog.chat.display(*args, **kwargs)

    def _debug_dialog_closed(self, result):
        """Forget the debug dialog once it has been closed.

        The result argument is not used.
        """
        self.debug_dialog = None

    def debug_input_text_sent(self, text):
        """Send debug buffer input to WeeChat.

        Nothing is done when not connected.  The message id is prefixed so
        that it starts with "debug": "(id) command" becomes
        "(debug_id) command" and a command without id becomes
        "(debug) command".  The command is recorded as a debug message,
        then sent.
        """
        # This object is the network itself: it has no "network"
        # attribute, its own methods must be called directly.
        if not self.is_connected():
            return
        text = str(text)
        pos = text.find(')')
        if text.startswith('(') and pos >= 0:
            text = f'(debug_{text[1:pos]}){text[pos + 1:]}'
        else:
            text = f'(debug) {text}'
        self.debug_print(0, '<==', text, forcecolor='#AA0000')
        self.send_to_weechat(text + '\n')

    def open_debug_dialog(self):
        """Open a dialog with debug messages.

        Does nothing if a dialog is already open.  Otherwise a dialog is
        created, its input and close events are connected to this object,
        and the debug messages recorded so far are displayed.
        """
        if not self.debug_dialog:
            # NOTE(review): DebugDialog is not imported in the part of the
            # file visible here - confirm where it is defined.
            self.debug_dialog = DebugDialog()
            # NOTE(review): "<signal>.connect(...)" looks like a Qt-style
            # signal API - confirm that DebugDialog provides it.
            self.debug_dialog.input.textSent.connect(
                self.debug_input_text_sent)
            self.debug_dialog.finished.connect(self._debug_dialog_closed)
            # show the messages recorded before the dialog was opened
            self.debug_dialog.display_lines(self.debug_lines)
            self.debug_dialog.chat.scroll_bottom()