Commit ea800f60 authored by Jakub Beránek

Add cluster kill functionality
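
A rough sketch of how the new Cluster record is meant to be used (the node name, PID, and workdir below are illustrative, not taken from a real run):

    import os
    from cluster import CLUSTER_FILENAME, Cluster

    workdir = "/tmp/workdir"  # illustrative path
    cluster = Cluster(workdir)
    # `up` adds one entry per started process, e.g. a data server on a node:
    cluster.add("node1", 12345, "python -m quake.datasrv ...", key="datasrv")
    with open(os.path.join(workdir, CLUSTER_FILENAME), "w") as f:
        cluster.serialize(f)

    # `down` reloads the record and sends TERM to each stored process group:
    with open(os.path.join(workdir, CLUSTER_FILENAME)) as f:
        Cluster.deserialize(f).kill()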

parent fd26ceb2
import json
import logging
import os
import socket
import subprocess
from multiprocessing import Pool

HOSTNAME = socket.gethostname()
CLUSTER_FILENAME = "cluster.json"


def kill_process_pool(args):
    # Unpack the (kill_fn, node, process) tuple; Pool.map passes a single argument.
    kill_fn, node, process = args
    kill_fn(node, process)


def default_kill_fn(node, process):
    return kill_process(node, process["pid"])


# Records the processes started on each node so that they can be killed later.
class Cluster:
    @staticmethod
    def deserialize(file):
        data = json.load(file)
        return Cluster(data["workdir"], data["nodes"])

    def __init__(self, workdir, nodes=None):
        if nodes is None:
            nodes = {}
        self.workdir = workdir
        self.nodes = nodes

    def add(self, node, pid, cmd, key=None, **kwargs):
        # Record a started process (command + pid) under the given node.
        data = {
            "cmd": cmd,
            "pid": pid
        }
        if key:
            data["key"] = key
        data.update(kwargs)
        self.nodes.setdefault(node, []).append(data)

    def processes(self):
        for (node, processes) in self.nodes.items():
            for process in processes:
                yield (node, process)

    def kill(self, kill_fn=None):
        # Kill all recorded processes in parallel.
        if kill_fn is None:
            kill_fn = default_kill_fn
        with Pool() as pool:
            pool.map(kill_process_pool,
                     [(kill_fn, node, process) for (node, process) in self.processes()])

    def get_processes_by_key(self, key):
        def gen():
            for (node, process) in self.processes():
                if process.get("key") == key:
                    yield (node, process)

        return list(gen())

    def get_monitor_for_node(self, node):
        processes = self.nodes.get(node, [])
        for process in processes:
            if process.get("key") == "monitor":
                return process
        return None

    def serialize(self, file):
        json.dump({
            "workdir": self.workdir,
            "nodes": self.nodes
        }, file, indent=2)


def is_local(host):
    return host == HOSTNAME or host == "localhost" or host == socket.gethostbyname(HOSTNAME)


def start_process(commands, host=None, workdir=None, name=None, env=None, init_cmd=""):
    if not workdir:
        workdir = os.getcwd()
    workdir = os.path.abspath(workdir)
    if init_cmd:
        init_cmd = f"{init_cmd} || exit 1"

    args = []
    if env:
        args += ["env"]
        for (key, val) in env.items():
            args += [f"{key}={val}"]
    args += [str(cmd) for cmd in commands]

    if not name:
        name = "process"
    output = os.path.join(workdir, name)
    logging.info(f"Running {' '.join(str(c) for c in commands)} on {host}")
    stdout_file = f"{output}.out"
    stderr_file = f"{output}.err"

    # Start the command in the background and print the PGID of the spawned process;
    # the PGID is what gets recorded and later killed as a whole process group.
    command = f"""
cd {workdir} || exit 1
{init_cmd}
ulimit -c unlimited
{' '.join(args)} > {stdout_file} 2> {stderr_file} &
ps -ho pgid $!
""".strip()

    cmd_args = []
    if host:
        cmd_args += ["ssh", host]
    else:
        # setsid detaches the local process into its own session (and process group).
        cmd_args += ["setsid"]
    cmd_args += ["/bin/bash"]

    process = subprocess.Popen(cmd_args, cwd=workdir, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE, stdin=subprocess.PIPE)
    out, err = process.communicate(command.encode())
    pid = out.strip()
    if not pid:
        logging.error(
            f"Process startup failed with status: {process.returncode}, "
            f"stderr: {err.decode()}, stdout: {out.decode()}")
        if os.path.isfile(stderr_file):
            with open(stderr_file) as f:
                logging.error("".join(f.readlines()))
        raise Exception(f"Process startup failed on {host if host else 'localhost'}: {command}")
    pid = int(pid)
    logging.info(f"PID: {pid}")
    return (pid, command)


def kill_process(host, pid, signal="TERM"):
    assert signal in ("TERM", "KILL", "INT")
    logging.info(f"Killing PGID {pid} on {host}")
    # A negative PID makes kill(1) signal the whole process group.
    args = ["kill", f"-{signal}", f"-{pid}"]
    if not is_local(host):
        args = ["ssh", host, "--"] + args
    res = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if res.returncode != 0:
        logging.error(f"error: {res.returncode} {res.stdout.decode().strip()} "
                      f"{res.stderr.decode().strip()}")
        return False
    return True
import logging
import os
import pathlib
import subprocess

import click as click

from cluster import CLUSTER_FILENAME, Cluster, HOSTNAME, start_process

CURRENT_DIR = pathlib.Path(__file__).absolute().parent
ROOT_DIR = CURRENT_DIR.parent
@@ -15,61 +16,6 @@ def prepare_directory(path):
    os.makedirs(path, exist_ok=True)


-def start_process(commands, host=None, workdir=None, name=None, env=None, init_cmd=""):
-    if not workdir:
-        workdir = os.getcwd()
-    workdir = os.path.abspath(workdir)
-    if init_cmd:
-        init_cmd = f"{init_cmd} || exit 1"
-    args = []
-    if env:
-        args += ["env"]
-        for (key, val) in env.items():
-            args += [f"{key}={val}"]
-    args += [str(cmd) for cmd in commands]
-    if not name:
-        name = "process"
-    output = os.path.join(workdir, name)
-    logging.info(f"Running {' '.join(str(c) for c in commands)} on {host}")
-    stdout_file = f"{output}.out"
-    stderr_file = f"{output}.err"
-    command = f"""
-cd {workdir} || exit 1
-{init_cmd}
-ulimit -c unlimited
-{' '.join(args)} > {stdout_file} 2> {stderr_file} &
-ps -ho pgid $!
-""".strip()
-    cmd_args = []
-    if host:
-        cmd_args += ["ssh", host]
-    else:
-        cmd_args += ["setsid"]
-    cmd_args += ["/bin/bash"]
-    process = subprocess.Popen(cmd_args, cwd=workdir, stdout=subprocess.PIPE,
-                               stderr=subprocess.PIPE, stdin=subprocess.PIPE)
-    out, err = process.communicate(command.encode())
-    pid = out.strip()
-    if not pid:
-        logging.error(
-            f"Process startup failed with status: {process.returncode}, stderr: {err.decode()}, stdout: {out.decode()}")
-        if os.path.isfile(stderr_file):
-            with open(stderr_file) as f:
-                logging.error("".join(f.readlines()))
-        raise Exception(f"Process startup failed on {host if host else 'localhost'}: {command}")
-    pid = int(pid)
-    logging.info(f"PID: {pid}")
-    return (pid, command)


def is_inside_pbs():
    return "PBS_NODEFILE" in os.environ
@@ -81,7 +27,7 @@ def get_pbs_nodes():
        return [line.strip() for line in f]


-def start_datasrv(node, workdir, env, init_cmd):
+def start_datasrv(cluster, node, workdir, env, init_cmd):
    datasrv_dir = workdir / f"{node}-datasrv"
    prepare_directory(datasrv_dir)
@@ -90,23 +36,26 @@ def start_datasrv(node, workdir, env, init_cmd):
name = "datasrv"
commands = ["python", "-m", "quake.datasrv", str(datasrv_data_dir), "--port", DATASRV_PORT]
start_process(commands, host=node, workdir=str(datasrv_dir), name=name, env=env,
init_cmd=init_cmd)
pid, cmd = start_process(commands, host=node, workdir=str(datasrv_dir), name=name, env=env,
init_cmd=init_cmd)
cluster.add(node, pid, cmd, key="datasrv")


-def start_server(workers, workdir, env, init_cmd):
+def start_server(cluster, workers, workdir, env, init_cmd):
    workdir = workdir / "server"
    prepare_directory(workdir)
    commands = ["python", "-m", "quake.server", "--ds-port", DATASRV_PORT, "--workers",
                ",".join(workers)]
-    start_process(commands, workdir=str(workdir), name="server", env=env, init_cmd=init_cmd)
+    pid, cmd = start_process(commands, workdir=str(workdir), name="server", env=env,
+                             init_cmd=init_cmd)
+    cluster.add(HOSTNAME, pid, cmd, key="server")


@click.command()
@click.argument("workdir")
@click.option("--init-cmd", default="")
-def pbs_deploy(workdir, init_cmd):
+def up(workdir, init_cmd):
    nodes = get_pbs_nodes()
    workdir = pathlib.Path(workdir).absolute()
@@ -115,12 +64,33 @@ def pbs_deploy(workdir, init_cmd):
    env = {}
    env["PYTHONPATH"] = f'{ROOT_DIR}:{env.get("PYTHONPATH", "")}'

+    cluster = Cluster(str(workdir))
    for node in nodes:
-        start_datasrv(node, workdir, env, init_cmd)
-    start_server(nodes, workdir, env, init_cmd)
+        start_datasrv(cluster, node, workdir, env, init_cmd)
+    start_server(cluster, nodes, workdir, env, init_cmd)
+
+    cluster_path = workdir / CLUSTER_FILENAME
+    logging.info(f"Writing cluster into {cluster_path}")
+    with open(cluster_path, "w") as f:
+        cluster.serialize(f)


+@click.command()
+@click.argument("workdir")
+def down(workdir):
+    with open(os.path.join(workdir, CLUSTER_FILENAME)) as f:
+        cluster = Cluster.deserialize(f)
+    cluster.kill()


+@click.group()
+def cli():
+    pass


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
-    pbs_deploy()
+    cli.add_command(up)
+    cli.add_command(down)
+    cli()