import json
import logging
import os
import socket
import subprocess
from multiprocessing import Pool

HOSTNAME = socket.gethostname()
CLUSTER_FILENAME = "cluster.json"


def kill_process_pool(args):
    # Unpacking helper for Pool.map, which passes each item as a single argument.
    kill_fn, node, process = args
    kill_fn(node, process)


def default_kill_fn(node, process):
    return kill_process(node, process["pid"])


class Cluster:
    """Tracks processes started on a set of nodes, keyed by hostname."""

    @staticmethod
    def deserialize(file):
        data = json.load(file)
        return Cluster(data["workdir"], data["nodes"])

    def __init__(self, workdir, nodes=None):
        if nodes is None:
            nodes = {}
        self.workdir = workdir
        self.nodes = nodes

    def add(self, node, pid, cmd, key=None, **kwargs):
        data = {
            "cmd": cmd,
            "pid": pid
        }
        if key:
            data["key"] = key
        data.update(kwargs)
        self.nodes.setdefault(node, []).append(data)

    def processes(self):
        # Yields (node, process) pairs across all nodes.
        for node, processes in self.nodes.items():
            for process in processes:
                yield (node, process)

    def kill(self, kill_fn=None):
        # kill_fn must be picklable (i.e. a module-level function, not a
        # lambda), since the tuples are shipped to Pool worker processes.
        if kill_fn is None:
            kill_fn = default_kill_fn
        with Pool() as pool:
            pool.map(kill_process_pool,
                     [(kill_fn, node, process) for (node, process) in self.processes()])

    def get_processes_by_key(self, key):
        return [(node, process) for (node, process) in self.processes()
                if process.get("key") == key]

    def get_monitor_for_node(self, node):
        for process in self.nodes.get(node, []):
            if process.get("key") == "monitor":
                return process
        return None

    def serialize(self, file):
        json.dump({
            "workdir": self.workdir,
            "nodes": self.nodes
        }, file, indent=2)
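

# A minimal usage sketch (hypothetical paths and names, not part of the
# original module): a Cluster round-trips through its JSON form, so the
# script that starts processes can record them and a later invocation can
# load the file and clean up:
#
#     cluster = Cluster(workdir="/tmp/experiment")
#     cluster.add("node1", pid=1234, cmd=["sleep", "60"], key="monitor")
#     with open(CLUSTER_FILENAME, "w") as f:
#         cluster.serialize(f)
#     # ...later, possibly from a different process:
#     with open(CLUSTER_FILENAME) as f:
#         Cluster.deserialize(f).kill()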


def is_local(host):
    # Treat our hostname, "localhost", and our resolved IP address as local.
    return host in (HOSTNAME, "localhost", socket.gethostbyname(HOSTNAME))


def start_process(commands, host=None, workdir=None, name=None, env=None, init_cmd=""):
    """Start a command in the background, locally or over ssh.

    Returns (pgid, command): the process group id of the started process and
    the shell snippet used to launch it.
    """
    if not workdir:
        workdir = os.getcwd()
    workdir = os.path.abspath(workdir)

    if init_cmd:
        init_cmd = f"{init_cmd} || exit 1"

    args = []
    if env:
        # Note: env values are not shell-quoted, so values containing spaces
        # or shell metacharacters will break the generated command.
        args += ["env"]
        for key, val in env.items():
            args += [f"{key}={val}"]
    args += [str(cmd) for cmd in commands]

    if not name:
        name = "process"
    output = os.path.join(workdir, name)
    stdout_file = f"{output}.out"
    stderr_file = f"{output}.err"
    logging.info(f"Running {' '.join(str(c) for c in commands)} on {host or 'localhost'}")

    # Launch the command in the background, then print the process group id
    # of the new process ($! is the PID of the last backgrounded job).
    command = f"""
    cd {workdir} || exit 1
    {init_cmd}
    ulimit -c unlimited
    {' '.join(args)} > {stdout_file} 2> {stderr_file} &
    ps -ho pgid $!
    """.strip()

    cmd_args = []
    if host:
        cmd_args += ["ssh", host]
    else:
        # setsid detaches the local shell into its own session, so the whole
        # process group can later be signalled without hitting ourselves.
        cmd_args += ["setsid"]
    cmd_args += ["/bin/bash"]

    process = subprocess.Popen(cmd_args, cwd=workdir, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE, stdin=subprocess.PIPE)
    out, err = process.communicate(command.encode())
    pid = out.strip()
    if not pid:
        logging.error(
            f"Process startup failed with status: {process.returncode}, "
            f"stderr: {err.decode()}, stdout: {out.decode()}")
        if os.path.isfile(stderr_file):
            with open(stderr_file) as f:
                logging.error(f.read())
        raise Exception(f"Process startup failed on {host if host else 'localhost'}: {command}")
    pid = int(pid)
    logging.info(f"PGID: {pid}")
    return (pid, command)


def kill_process(host, pid, signal="TERM"):
    # Signals the whole process group (note the leading "-" before the pid).
    assert signal in ("TERM", "KILL", "INT")
    logging.info(f"Killing PGID {pid} on {host}")
    args = ["kill", f"-{signal}", f"-{pid}"]
    if not is_local(host):
        args = ["ssh", host, "--"] + args
    res = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if res.returncode != 0:
        logging.error(f"error: {res.returncode} {res.stdout.decode().strip()} "
                      f"{res.stderr.decode().strip()}")
        return False
    return True
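

# End-to-end sketch (an assumed workflow, not part of the original module):
# start a worker per node, record each one in the Cluster, and persist the
# state so a later run can kill everything. The node list and the "sleep"
# worker command are placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    cluster = Cluster(workdir=os.getcwd())
    for node in ("localhost",):  # hypothetical node list
        pid, cmd = start_process(["sleep", "600"],
                                 host=None if is_local(node) else node,
                                 workdir=cluster.workdir,
                                 name=f"worker-{node}")
        cluster.add(node, pid, cmd, key="worker")
    with open(CLUSTER_FILENAME, "w") as f:
        cluster.serialize(f)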