Skip to content
Snippets Groups Projects
cluster.py 4.05 KiB
Newer Older
  • Learn to ignore specific revisions
  • import json
    import logging
    import os
    import socket
    import subprocess
    from multiprocessing import Pool
    
    HOSTNAME = socket.gethostname()
    CLUSTER_FILENAME = "cluster.json"
    
    
    def kill_process_pool(args):
        kill_fn, node, process = args
        kill_fn(node, process)
    
    
    def default_kill_fn(node, process):
        return kill_process(node, process["pid"])
    
    
    class Cluster:
        @staticmethod
        def deserialize(file):
            data = json.load(file)
            return Cluster(data["workdir"], data["nodes"])
    
        def __init__(self, workdir, nodes=None):
            if nodes is None:
                nodes = {}
    
            self.workdir = workdir
            self.nodes = nodes
    
        def add(self, node, pid, cmd, key=None, **kwargs):
            data = {
                "cmd": cmd,
                "pid": pid
            }
            if key:
                data["key"] = key
            data.update(kwargs)
    
            self.nodes.setdefault(node, []).append(data)
    
        def processes(self):
            for (node, processes) in self.nodes.items():
                for process in processes:
                    yield (node, process)
    
        def kill(self, kill_fn=None):
            if kill_fn is None:
                kill_fn = default_kill_fn
    
            with Pool() as pool:
                pool.map(kill_process_pool, [(kill_fn, node, process) for (node, process) in self.processes()])
    
        def get_processes_by_key(self, key):
            def gen():
                for (node, process) in self.processes():
                    if process.get("key") == key:
                        yield (node, process)
            return list(gen())
    
        def get_monitor_for_node(self, node):
            processes = self.nodes.get(node, [])
            for process in processes:
                if process.get("key") == "monitor":
                    return process
            return None
    
        def serialize(self, file):
            json.dump({
                "workdir": self.workdir,
                "nodes": self.nodes
            }, file, indent=2)
    
    
    def is_local(host):
        return host == HOSTNAME or host == "localhost" or host == socket.gethostbyname(HOSTNAME)
    
    
    def start_process(commands, host=None, workdir=None, name=None, env=None, init_cmd=""):
        if not workdir:
            workdir = os.getcwd()
        workdir = os.path.abspath(workdir)
    
        if init_cmd:
            init_cmd = f"{init_cmd} || exit 1"
    
        args = []
        if env:
            args += ["env"]
            for (key, val) in env.items():
                args += [f"{key}={val}"]
    
        args += [str(cmd) for cmd in commands]
    
        if not name:
            name = "process"
        output = os.path.join(workdir, name)
    
        logging.info(f"Running {' '.join(str(c) for c in commands)} on {host}")
    
        stdout_file = f"{output}.out"
        stderr_file = f"{output}.err"
        command = f"""
    cd {workdir} || exit 1
    {init_cmd}
    ulimit -c unlimited
    {' '.join(args)} > {stdout_file} 2> {stderr_file} &
    ps -ho pgid $!
    """.strip()
    
        cmd_args = []
        if host:
            cmd_args += ["ssh", host]
        else:
            cmd_args += ["setsid"]
        cmd_args += ["/bin/bash"]
        process = subprocess.Popen(cmd_args, cwd=workdir, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE, stdin=subprocess.PIPE)
        out, err = process.communicate(command.encode())
        pid = out.strip()
    
        if not pid:
            logging.error(
                f"Process startup failed with status: {process.returncode}, stderr: {err.decode()}, stdout: {out.decode()}")
            if os.path.isfile(stderr_file):
                with open(stderr_file) as f:
                    logging.error("".join(f.readlines()))
            raise Exception(f"Process startup failed on {host if host else 'localhost'}: {command}")
        pid = int(pid)
        logging.info(f"PID: {pid}")
        return (pid, command)
    
    
    def kill_process(host, pid, signal="TERM"):
        assert signal in ("TERM", "KILL", "INT")
        logging.info(f"Killing PGID {pid} on {host}")
        args = ["kill", f"-{signal}", f"-{pid}"]
        if not is_local(host):
            args = ["ssh", host, "--"] + args
        res = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if res.returncode != 0:
            logging.error(f"error: {res.returncode} {res.stdout.decode().strip()} {res.stderr.decode().strip()}")
            return False
        return True