Forked from
Ada Böhm / quake
25 commits behind the upstream repository.
-
Stanislav Bohm authored
cluster.py 4.05 KiB
import json
import logging
import os
import shlex
import socket
import subprocess
from multiprocessing import Pool
# Host name of the machine this module was imported on (evaluated once at
# import time); is_local() compares candidate hosts against it.
HOSTNAME = socket.gethostname()
# NOTE(review): presumably the default file name for a serialized Cluster;
# it is not referenced anywhere in this section -- confirm against callers.
CLUSTER_FILENAME = "cluster.json"
def kill_process_pool(args):
    """Call a kill callback from a packed ``(kill_fn, node, process)`` tuple.

    The single-tuple signature lets this function be handed directly to
    ``Pool.map``, which passes exactly one argument per work item.
    """
    callback, target_node, target_process = args
    callback(target_node, target_process)
def default_kill_fn(node, process):
    """Default kill callback: signal the process recorded in ``process``.

    Looks up ``process["pid"]`` and forwards it, together with ``node``, to
    ``kill_process`` using that function's default signal. Returns whatever
    ``kill_process`` returns.
    """
    recorded_pid = process["pid"]
    return kill_process(node, recorded_pid)
class Cluster:
    """Registry of processes started on the nodes of a cluster.

    Attributes:
        workdir: Working directory associated with the cluster.
        nodes: Mapping of node name to a list of process records. Each
            record is a dict with ``"cmd"`` and ``"pid"``, an optional
            ``"key"``, and any extra fields passed to :meth:`add`.
    """

    @classmethod
    def deserialize(cls, file):
        """Rebuild a cluster from an open JSON file written by :meth:`serialize`.

        A classmethod so that subclasses deserialize into their own type.

        Raises:
            KeyError: If the document lacks ``"workdir"`` or ``"nodes"``.
        """
        data = json.load(file)
        return cls(data["workdir"], data["nodes"])

    def __init__(self, workdir, nodes=None):
        """Create a cluster rooted at ``workdir`` with optional initial ``nodes``."""
        # A fresh dict per instance; never share a mutable default.
        self.workdir = workdir
        self.nodes = {} if nodes is None else nodes

    def add(self, node, pid, cmd, key=None, **kwargs):
        """Record a process running on ``node``.

        Args:
            node: Name of the node the process runs on.
            pid: Identifier later handed to the kill function.
            cmd: Command that was used to start the process.
            key: Optional label; stored only when truthy.
            **kwargs: Extra fields merged into the stored record.
        """
        record = {"cmd": cmd, "pid": pid}
        if key:
            record["key"] = key
        record.update(kwargs)
        self.nodes.setdefault(node, []).append(record)

    def processes(self):
        """Yield ``(node, process_record)`` for every registered process."""
        for node, records in self.nodes.items():
            for record in records:
                yield (node, record)

    def kill(self, kill_fn=None):
        """Kill every registered process in parallel.

        Args:
            kill_fn: Callable ``kill_fn(node, process_record)``; defaults to
                ``default_kill_fn``. It is sent to worker processes, so it
                must be picklable.
        """
        targets = list(self.processes())
        if not targets:
            # Nothing registered: do not spin up a worker pool for no work.
            return
        if kill_fn is None:
            kill_fn = default_kill_fn
        jobs = [(kill_fn, node, record) for node, record in targets]
        with Pool() as pool:
            pool.map(kill_process_pool, jobs)

    def get_processes_by_key(self, key):
        """Return a list of ``(node, process_record)`` whose ``"key"`` equals ``key``."""
        return [
            (node, record)
            for node, record in self.processes()
            if record.get("key") == key
        ]

    def get_monitor_for_node(self, node):
        """Return the first record on ``node`` keyed ``"monitor"``, or ``None``.

        Unknown nodes are treated as having no processes.
        """
        for record in self.nodes.get(node, []):
            if record.get("key") == "monitor":
                return record
        return None

    def serialize(self, file):
        """Write the cluster as indented JSON to the open text file ``file``."""
        json.dump({
            "workdir": self.workdir,
            "nodes": self.nodes
        }, file, indent=2)
def is_local(host):
    """Return True when ``host`` refers to the machine running this code.

    A host is local when it is falsy (``None`` or ``""``, the same
    convention ``start_process`` uses for "run here"), equals ``HOSTNAME``
    or ``"localhost"``, or equals the IP address ``HOSTNAME`` resolves to.
    """
    if not host:
        return True
    if host in (HOSTNAME, "localhost"):
        return True
    try:
        return host == socket.gethostbyname(HOSTNAME)
    except OSError:
        # socket.gaierror (an OSError) is raised when the local host name
        # cannot be resolved; then there is no address to compare against.
        return False
def start_process(commands, host=None, workdir=None, name=None, env=None, init_cmd=""):
    """Start ``commands`` as a background process, locally or through ssh.

    A short bash script is piped into ``/bin/bash``. The shell is run under
    ``setsid`` when ``host`` is falsy, or via ``ssh host`` otherwise. The
    script changes into ``workdir``, runs ``init_cmd``, launches the command
    in the background with stdout/stderr redirected to
    ``<workdir>/<name>.out`` and ``<workdir>/<name>.err``, and prints the
    process group id of the background job.

    Args:
        commands: Iterable of command words; each is converted with ``str``.
        host: Host to ssh to; falsy means run on this machine.
        workdir: Working directory (made absolute); defaults to the current
            directory.
        name: Base name for the output files; defaults to ``"process"``.
        env: Optional mapping of environment variables, passed via ``env``.
        init_cmd: Optional shell snippet run first; the script exits with
            status 1 if it fails.

    Returns:
        Tuple ``(pgid, script)``: the integer printed by ``ps -ho pgid`` and
        the bash script text that was executed.

    Raises:
        Exception: If the script did not print a numeric process group id.
    """
    workdir = os.path.abspath(workdir or os.getcwd())

    if init_cmd:
        init_cmd = f"{init_cmd} || exit 1"

    # Materialize once so one-shot iterables work for both the script and
    # the log message.
    commands = [str(cmd) for cmd in commands]

    launch = []
    if env:
        launch.append("env")
        launch.extend(f"{key}={val}" for (key, val) in env.items())
    launch.extend(commands)
    # The command words are deliberately left unquoted, exactly as before,
    # in case callers rely on shell expansion inside them.
    launch_line = " ".join(launch)

    output = os.path.join(workdir, name or "process")
    logging.info(f"Running {' '.join(commands)} on {host}")

    stdout_file = f"{output}.out"
    stderr_file = f"{output}.err"

    # Paths are shell-quoted so directories or names containing spaces or
    # shell metacharacters do not break (or alter) the script.
    quoted_workdir = shlex.quote(workdir)
    quoted_stdout = shlex.quote(stdout_file)
    quoted_stderr = shlex.quote(stderr_file)

    command = f"""
cd {quoted_workdir} || exit 1
{init_cmd}
ulimit -c unlimited
{launch_line} > {quoted_stdout} 2> {quoted_stderr} &
ps -ho pgid $!
""".strip()

    cmd_args = ["ssh", host] if host else ["setsid"]
    cmd_args.append("/bin/bash")

    # NOTE: cwd=workdir applies to the local ssh/setsid process, so workdir
    # must exist on this machine even when ``host`` is remote.
    process = subprocess.Popen(cmd_args, cwd=workdir, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE, stdin=subprocess.PIPE)
    out, err = process.communicate(command.encode())

    pid = out.strip()
    # Empty output means ps found no job; non-numeric output means something
    # other than the pgid reached stdout. Both are startup failures.
    if not pid.isdigit():
        logging.error(
            f"Process startup failed with status: {process.returncode}, "
            f"stderr: {err.decode(errors='replace')}, stdout: {out.decode(errors='replace')}")
        # The file is looked up on the local filesystem only.
        if os.path.isfile(stderr_file):
            with open(stderr_file) as f:
                logging.error(f.read())
        raise Exception(f"Process startup failed on {host if host else 'localhost'}: {command}")

    pid = int(pid)
    logging.info(f"PID: {pid}")
    return (pid, command)
def kill_process(host, pid, signal="TERM"):
    """Send ``signal`` to the process group ``pid`` on ``host``.

    Runs ``kill -<signal> -<pid>`` (the negative id addresses a process
    group), directly when the host is local or falsy, otherwise through
    ``ssh host --``.

    Args:
        host: Host the process group lives on; falsy means this machine.
        pid: Process group id to signal.
        signal: One of ``"TERM"``, ``"KILL"`` or ``"INT"``.

    Returns:
        True if the kill command exited with status 0, False otherwise (the
        failure is logged).

    Raises:
        ValueError: If ``signal`` is not one of the supported names.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and this value ends up on a (possibly remote) command line.
    if signal not in ("TERM", "KILL", "INT"):
        raise ValueError(f"Unsupported signal: {signal!r}")

    logging.info(f"Killing PGID {pid} on {host}")
    args = ["kill", f"-{signal}", f"-{pid}"]
    # A falsy host means "this machine", matching start_process; never build
    # an ssh command line with a missing host.
    if host and not is_local(host):
        args = ["ssh", host, "--"] + args

    res = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if res.returncode != 0:
        stdout_text = res.stdout.decode(errors="replace").strip()
        stderr_text = res.stderr.decode(errors="replace").strip()
        logging.error(f"error: {res.returncode} {stdout_text} {stderr_text}")
        return False
    return True