import json
import logging
import os
import shlex
import socket
import subprocess
from multiprocessing import Pool

# Host name of the machine running this module, captured once at import time.
HOSTNAME = socket.gethostname()
# NOTE(review): presumably the file name a serialized Cluster is stored under;
# it is not referenced in this part of the file -- confirm against callers.
CLUSTER_FILENAME = "cluster.json"


def kill_process_pool(args):
    """Pool task adapter: unpack one ``(kill_fn, node, process)`` tuple and run it.

    ``Pool.map`` hands each worker a single argument, so the three values are
    packed into one tuple by the caller. The result of ``kill_fn`` is
    discarded; this function always returns ``None``.
    """
    fn, target_node, target_process = args
    fn(target_node, target_process)


def default_kill_fn(node, process):
    """Kill ``process`` on ``node`` using the id stored under its ``"pid"`` key.

    Delegates to ``kill_process`` with its default signal and returns that
    function's result.
    """
    recorded_pid = process["pid"]
    return kill_process(node, recorded_pid)


class Cluster:
    """Registry of the processes started on each node of a cluster.

    ``nodes`` maps a node name to a list of process records. Each record is a
    dict with at least ``"cmd"`` and ``"pid"``, an optional ``"key"``, and any
    extra fields that were passed to :meth:`add`.
    """

    @classmethod
    def deserialize(cls, file):
        """Rebuild a cluster from the JSON document written by :meth:`serialize`.

        ``file`` is a readable text file object. This is a classmethod so that
        calling it on a subclass returns an instance of that subclass.
        """
        data = json.load(file)
        return cls(data["workdir"], data["nodes"])

    def __init__(self, workdir, nodes=None):
        """Create a cluster rooted at ``workdir`` with optional existing ``nodes``.

        When ``nodes`` is given it is stored as-is (not copied).
        """
        self.workdir = workdir
        self.nodes = {} if nodes is None else nodes

    def add(self, node, pid, cmd, key=None, **kwargs):
        """Record a process running on ``node``.

        ``key`` is stored only when truthy. Extra keyword arguments are merged
        into the record last, so they can override the standard fields.
        """
        data = {
            "cmd": cmd,
            "pid": pid,
        }
        if key:
            data["key"] = key
        data.update(kwargs)

        self.nodes.setdefault(node, []).append(data)

    def processes(self):
        """Yield ``(node, process_record)`` for every recorded process."""
        for node, processes in self.nodes.items():
            for process in processes:
                yield (node, process)

    def kill(self, kill_fn=None):
        """Call ``kill_fn(node, process)`` for every recorded process, in parallel.

        ``kill_fn`` defaults to ``default_kill_fn``. The calls run in a
        ``multiprocessing.Pool``, so ``kill_fn`` must be picklable. Results of
        the individual calls are not returned.
        """
        targets = list(self.processes())
        if not targets:
            # Nothing to kill: avoid spawning a whole worker pool for no work.
            return

        if kill_fn is None:
            kill_fn = default_kill_fn

        with Pool() as pool:
            pool.map(kill_process_pool, [(kill_fn, node, process) for (node, process) in targets])

    def get_processes_by_key(self, key):
        """Return a list of ``(node, process_record)`` whose ``"key"`` equals ``key``."""
        return [
            (node, process)
            for (node, process) in self.processes()
            if process.get("key") == key
        ]

    def get_monitor_for_node(self, node):
        """Return the first record on ``node`` whose key is ``"monitor"``, or ``None``."""
        for process in self.nodes.get(node, []):
            if process.get("key") == "monitor":
                return process
        return None

    def serialize(self, file):
        """Write the cluster (``workdir`` and ``nodes``) as indented JSON to ``file``."""
        json.dump({
            "workdir": self.workdir,
            "nodes": self.nodes,
        }, file, indent=2)


def is_local(host):
    """Return True when ``host`` refers to the machine running this code.

    A host counts as local when it is unset/empty (``start_process`` runs
    locally in that case), is ``"localhost"``, equals ``HOSTNAME``, or equals
    the address ``HOSTNAME`` resolves to.
    """
    if not host or host == "localhost":
        return True
    if host == HOSTNAME:
        return True
    try:
        return host == socket.gethostbyname(HOSTNAME)
    except OSError:
        # The local host name does not resolve (socket.gaierror is an OSError);
        # with no address to compare against, treat the host as remote rather
        # than crashing the caller.
        return False


def start_process(commands, host=None, workdir=None, name=None, env=None, init_cmd=""):
    """Start ``commands`` as a background job and return its process group id.

    A short bash script is fed on stdin to ``/bin/bash``, which is run under
    ``setsid`` locally or through ``ssh host`` when ``host`` is given. The
    script changes to ``workdir``, runs ``init_cmd`` (aborting if it fails),
    lifts the core-file size limit, starts the command in the background with
    stdout/stderr redirected to ``<workdir>/<name>.out`` and ``.err``, and
    prints the job's PGID via ``ps``.

    Args:
        commands: Iterable of command words; each is converted with ``str``
            and the words are joined with spaces.
        host: Host to ssh to; falsy means run on the local machine.
        workdir: Working directory (defaults to the current directory).
        name: Base name of the output files (defaults to ``"process"``).
        env: Optional mapping of environment variables, passed via ``env``.
        init_cmd: Optional shell snippet executed before the command.

    Returns:
        ``(pgid, script)`` where ``pgid`` is an int and ``script`` is the
        bash script that was executed.

    Raises:
        Exception: If no PGID could be read from the script's output.
    """
    workdir = os.path.abspath(workdir or os.getcwd())

    if init_cmd:
        init_cmd = f"{init_cmd} || exit 1"

    # Materialize once so one-shot iterables work for both the script and the log line.
    command_words = [str(cmd) for cmd in commands]

    args = []
    if env:
        args.append("env")
        args.extend(f"{key}={val}" for (key, val) in env.items())
    args.extend(command_words)

    output = os.path.join(workdir, name or "process")

    logging.info(f"Running {' '.join(command_words)} on {host}")

    stdout_file = f"{output}.out"
    stderr_file = f"{output}.err"
    # Paths are quoted so directories containing spaces or shell metacharacters
    # survive. NOTE(review): command words and env values are still inserted
    # verbatim (callers may rely on shell expansion there) -- never pass
    # untrusted input through them.
    command = f"""
cd {shlex.quote(workdir)} || exit 1
{init_cmd}
ulimit -c unlimited
{' '.join(args)} > {shlex.quote(stdout_file)} 2> {shlex.quote(stderr_file)} &
ps -ho pgid $!
""".strip()

    cmd_args = ["ssh", host] if host else ["setsid"]
    cmd_args.append("/bin/bash")
    process = subprocess.Popen(cmd_args, cwd=workdir, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE, stdin=subprocess.PIPE)
    out, err = process.communicate(command.encode())

    # `ps` is the last command of the script, so the PGID is the last stdout
    # line; anything before it (e.g. output printed by init_cmd) is ignored.
    stdout_lines = out.strip().splitlines()
    pgid_text = stdout_lines[-1].strip() if stdout_lines else b""

    if not pgid_text.isdigit():
        logging.error(
            f"Process startup failed with status: {process.returncode}, stderr: {err.decode()}, stdout: {out.decode()}")
        # The job's own stderr file is read from the local filesystem.
        if os.path.isfile(stderr_file):
            with open(stderr_file) as f:
                logging.error(f.read())
        raise Exception(f"Process startup failed on {host if host else 'localhost'}: {command}")
    pid = int(pgid_text)
    logging.info(f"PID: {pid}")
    return (pid, command)


def kill_process(host, pid, signal="TERM"):
    """Send ``signal`` to the process group ``pid`` on ``host``.

    Runs ``kill -<signal> -<pid>`` directly when ``is_local(host)`` is true,
    otherwise through ``ssh host --``. The leading minus on the pid makes
    ``kill`` target the whole process group.

    Args:
        host: Host on which the process group lives.
        pid: Process group id to signal.
        signal: One of ``"TERM"``, ``"KILL"`` or ``"INT"``.

    Returns:
        True when the ``kill`` command exits with status 0, False otherwise
        (the failure is logged).

    Raises:
        ValueError: If ``signal`` is not one of the supported names.
    """
    # Validate explicitly: an `assert` is stripped under `python -O`, and the
    # value ends up on a (possibly remote) shell command line.
    if signal not in ("TERM", "KILL", "INT"):
        raise ValueError(f"Unsupported signal {signal!r}; expected TERM, KILL or INT")
    logging.info(f"Killing PGID {pid} on {host}")
    args = ["kill", f"-{signal}", f"-{pid}"]
    if not is_local(host):
        args = ["ssh", host, "--"] + args
    res = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if res.returncode != 0:
        logging.error(f"error: {res.returncode} {res.stdout.decode().strip()} {res.stderr.decode().strip()}")
        return False
    return True