cluster.py
    import dataclasses
    import json
    import logging
    import os
    import socket
    import subprocess
    from multiprocessing.pool import ThreadPool
    from typing import Dict, Iterator, List, Optional, Tuple
    
    import dacite
    
    HOSTNAME = socket.gethostname()
    
    
    @dataclasses.dataclass
    class Process:
        # Command used to launch the process
        cmd: str
        pid: int
        key: Optional[str] = None
        attrs: Dict = dataclasses.field(default_factory=dict)
    
    
    @dataclasses.dataclass
    class Node:
        processes: List[Process]
    
    
    class Cluster:
        """
        Stores information about processes running on a cluster.
        The information about the cluster can be serialized and deserialized and the processes
        may be later killed (even when running on a remote node, if it's accessible by SSH).
        """
    
        @staticmethod
        def deserialize(file) -> "Cluster":
            data = json.load(file)
            nodes = {k: dacite.from_dict(Node, v) for (k, v) in data["nodes"].items()}
            return Cluster(data["workdir"], nodes)
    
        def __init__(self, workdir: str, nodes: Optional[Dict[str, Node]] = None):
            if nodes is None:
                nodes = {}
    
            self.workdir = workdir
            self.nodes = nodes
    
        def add(self, node: str, pid: int, cmd: str, key: Optional[str] = None, **kwargs) -> Process:
            process = Process(
                cmd=cmd,
                pid=pid,
                key=key,
                attrs=kwargs
            )
    
            if node not in self.nodes:
                self.nodes[node] = Node(processes=[])
            self.nodes[node].processes.append(process)
            return process
    
        def processes(self) -> Iterator[Tuple[str, Process]]:
            for (address, node) in self.nodes.items():
                for process in node.processes:
                    yield (address, process)
    
        def kill(self, kill_fn=None):
            if kill_fn is None:
                kill_fn = lambda node, process: kill_process(node, process.pid)

            # Use a thread pool rather than a process pool: `kill_fn` is often a lambda,
            # which cannot be pickled for a worker process, and the work is I/O-bound (SSH).
            with ThreadPool() as pool:
                pool.map(kill_process_pool,
                         [(kill_fn, node, process) for (node, process) in self.processes()])
    
        def get_processes(self, key: Optional[str] = None,
                          node: Optional[str] = None) -> Iterator[Tuple[str, Process]]:
            for (address, process) in self.processes():
                if key is not None and process.key != key:
                    continue
                if node is not None and address != node:
                    continue
                yield (address, process)
    
        def serialize(self, file):
            json.dump({
                "workdir": self.workdir,
                "nodes": {address: dataclasses.asdict(node) for (address, node) in
                          self.nodes.items()}
            }, file, indent=2)
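        # The resulting JSON layout (the values shown are illustrative):
        #
        #   {
        #     "workdir": "/tmp/experiment",
        #     "nodes": {
        #       "node1": {
        #         "processes": [
        #           {"cmd": "python3 worker.py", "pid": 4242, "key": "worker", "attrs": {}}
        #         ]
        #       }
        #     }
        #   }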
    
        def __repr__(self):
            out = f"Workdir: {self.workdir}\n"
            out += "Nodes:\n"
            for (address, node) in self.nodes.items():
                out += f"{address}: {node}\n"
            return out
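    # A minimal usage sketch (the paths and hostname below are illustrative): record a
    # process, save the cluster state to disk, and restore it later to kill everything.
    #
    #   cluster = Cluster(workdir="/tmp/experiment")
    #   cluster.add("node1", pid=4242, cmd="python3 worker.py", key="worker")
    #   with open("cluster.json", "w") as f:
    #       cluster.serialize(f)
    #
    #   with open("cluster.json") as f:
    #       restored = Cluster.deserialize(f)
    #   restored.kill()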
    
    
    def is_local(host: str) -> bool:
        """
        Returns true if the given `host` is the local computer.
        """
        return host == HOSTNAME or host == "localhost" or host == socket.gethostbyname(HOSTNAME)
    
    
    def start_process(
            commands: List[str],
            host: Optional[str] = None,
            workdir: Optional[str] = None,
            modules: Optional[List[str]] = None,
            name: Optional[str] = None,
            env: Optional[Dict[str, str]] = None,
            pyenv: Optional[str] = None,
            init_cmd: Optional[List[str]] = None
    ):
        """
        Start a process on the given `host`. An example call is sketched after this function.
    
        :param commands: List of commands to run.
        :param host: Hostname where to start the process.
        :param workdir: Working directory of the process.
        :param modules: LMOD modules to load on the host.
        :param name: Name (used for stdout/stderr files in the working directory).
        :param env: Environment variables passed to the process.
        :param pyenv: Python virtual environment that should be sourced by the process.
        :param init_cmd: Initialization commands performed at the start of the process.
        """
        if not workdir:
            workdir = os.getcwd()
        workdir = os.path.abspath(workdir)
    
        init_cmd = init_cmd if init_cmd is not None else []
        init_cmd = list(init_cmd)
        if modules is not None:
            init_cmd += [f"ml {' '.join(modules)}"]
    
        if pyenv:
            assert os.path.isabs(pyenv)
            init_cmd += [f"source {pyenv}/bin/activate"]
    
        args = []
        if env:
            args += ["env"]
            for (key, val) in env.items():
                args += [f"{key}={val}"]
    
        args += [str(cmd) for cmd in commands]
    
        if not name:
            name = "process"
        output = os.path.join(workdir, name)
    
        stdout_file = f"{output}.out"
        stderr_file = f"{output}.err"
        command = f"""
    cd {workdir} || exit 1
    {' && '.join(f"{{ {cmd} || exit 1; }}" for cmd in init_cmd)}
    ulimit -c unlimited
    {' '.join(args)} > {stdout_file} 2> {stderr_file} &
    ps -ho pgid $!
    """.strip()
    
        cmd_args = []
        if host:
            cmd_args += ["ssh", host]
        else:
            cmd_args += ["setsid"]
        cmd_args += ["/bin/bash"]
        process = subprocess.Popen(cmd_args, cwd=workdir, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE, stdin=subprocess.PIPE)
        out, err = process.communicate(command.encode())
        pid = out.strip()
    
        if not pid:
            logging.error(
                f"Process startup failed with status: {process.returncode}, stderr: {err.decode()}, stdout: {out.decode()}")
            if os.path.isfile(stderr_file):
                with open(stderr_file) as f:
                    logging.error("".join(f.readlines()))
            raise Exception(f"Process startup failed on {host if host else 'localhost'}: {command}")
        return (int(pid), command)
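    # Example call (a sketch; the hostname, module, and paths are hypothetical):
    #
    #   pid, cmd = start_process(
    #       ["python3", "worker.py", "--port", "9000"],
    #       host="node1",
    #       workdir="/tmp/experiment",
    #       modules=["Python/3.10"],
    #       name="worker-0",
    #       env={"OMP_NUM_THREADS": "4"},
    #       pyenv="/opt/venv",
    #   )
    #
    # The returned `pid` is the process group id of the launched command; store it with
    # `Cluster.add` so that the whole group can be terminated later via `Cluster.kill`.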
    
    
    def kill_process_pool(args):
        kill_fn, node, process = args
        kill_fn(node, process)
    
    
    def kill_process(host: str, pid: int, signal="TERM"):
        """
        Kill a process with the given `pid` on the specified `host`
        :param host: Hostname where the process is located.
        :param pid: PID of the process to kill.
        :param signal: Signal used to kill the process. One of "TERM", "KILL" or "INT".
        """
        assert signal in ("TERM", "KILL", "INT")
        logging.debug(f"Killing PGID {pid} on {host}")
        args = ["kill", f"-{signal}", f"-{pid}"]
        if not is_local(host):
            args = ["ssh", host, "--"] + args
        res = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if res.returncode != 0:
            logging.error(
                f"error: {res.returncode} {res.stdout.decode().strip()} {res.stderr.decode().strip()}")
            return False
        return True
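    # Example (hypothetical host and PGID): forcefully kill a whole process group over SSH.
    #
    #   kill_process("node1", 4242, signal="KILL")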