Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import dataclasses
import json
import logging
import os
import socket
import subprocess
from multiprocessing import Pool
from typing import Dict, Iterator, List, Optional, Tuple
import dacite
# Hostname of the machine running this module, computed once at import time;
# `is_local` compares host names against it.
HOSTNAME: str = socket.gethostname()
@dataclasses.dataclass
class Process:
    """A single process recorded in a `Cluster`."""

    # Command used to launch the process
    cmd: str
    # Identifier handed to the kill function (by default passed to `kill_process` as the PGID)
    pid: int
    # Optional label; `Cluster.get_processes` can filter on it
    key: Optional[str] = None
    # Extra attributes (the keyword arguments given to `Cluster.add`); a new dict per instance
    attrs: Dict = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class Node:
    """The processes tracked on one cluster node (stored by address in `Cluster.nodes`)."""

    # Processes recorded on this node, in the order they were added
    processes: List[Process]
class Cluster:
    """
    Stores information about processes running on a cluster.

    Processes are grouped by the address of the node they run on. The information about the
    cluster can be serialized to and deserialized from JSON, and the processes may be later
    killed (even when running on a remote node, if it's accessible by SSH).
    """

    @staticmethod
    def deserialize(file) -> "Cluster":
        """
        Load a cluster from a JSON document in the format written by `serialize`.

        :param file: Readable file object containing the JSON document.
        :return: A new `Cluster` with the stored working directory and nodes.
        """
        data = json.load(file)
        nodes = {k: dacite.from_dict(Node, v) for (k, v) in data["nodes"].items()}
        return Cluster(data["workdir"], nodes)

    def __init__(self, workdir: str, nodes: Optional[Dict[str, "Node"]] = None):
        """
        :param workdir: Working directory associated with the cluster.
        :param nodes: Mapping from node address to its `Node`; a new empty dict when omitted.
        """
        if nodes is None:
            # Created here (not as a default value) so instances never share one dict.
            nodes = {}
        self.workdir = workdir
        self.nodes = nodes

    def add(self, node: str, pid: int, cmd: str, key: Optional[str] = None, **kwargs) -> "Process":
        """
        Record a process running on `node`, creating the node entry on first use.

        :param node: Address of the node the process runs on.
        :param pid: Identifier of the process, later handed to the kill function.
        :param cmd: Command used to launch the process.
        :param key: Optional label usable as a filter in `get_processes`.
        :param kwargs: Extra attributes stored in `Process.attrs`.
        :return: The newly recorded `Process`.
        """
        process = Process(
            cmd=cmd,
            pid=pid,
            key=key,
            attrs=kwargs
        )
        if node not in self.nodes:
            self.nodes[node] = Node(processes=[])
        self.nodes[node].processes.append(process)
        return process

    def processes(self) -> Iterator[Tuple[str, "Process"]]:
        """Yield `(node address, process)` pairs for every recorded process."""
        for (address, node) in self.nodes.items():
            for process in node.processes:
                yield (address, process)

    @staticmethod
    def _kill_by_pid(node: str, process: "Process"):
        """Default kill function of `kill`: calls `kill_process(node, process.pid)`."""
        return kill_process(node, process.pid)

    def kill(self, kill_fn=None):
        """
        Kill all recorded processes in parallel using a multiprocessing pool.

        :param kill_fn: Callable `(node address, process)` invoked once per process. It is sent
            to pool worker processes, so it must be picklable (e.g. a module-level function).
            Defaults to `kill_process(node, process.pid)` with its default signal.
        """
        if kill_fn is None:
            # A lambda cannot be pickled for the pool workers; a named function defined in the
            # class body is pickled by its qualified name instead.
            kill_fn = Cluster._kill_by_pid
        tasks = [(kill_fn, node, process) for (node, process) in self.processes()]
        if not tasks:
            # Nothing to kill; don't spin up worker processes.
            return
        with Pool() as pool:
            pool.map(kill_process_pool, tasks)

    def get_processes(self, key: Optional[str] = None,
                      node: Optional[str] = None) -> Iterator[Tuple[str, "Process"]]:
        """
        Yield `(node address, process)` pairs matching the given filters.

        :param key: When given, only processes whose `key` equals it are yielded.
        :param node: When given, only processes on the node with this address are yielded.
        """
        for (address, process) in self.processes():
            if key is not None and process.key != key:
                continue
            if node is not None and address != node:
                continue
            # Yield the real address; the `node` filter argument may be None.
            yield (address, process)

    def serialize(self, file):
        """
        Write the cluster as indented JSON with `workdir` and `nodes` entries.

        :param file: Writable text file object.
        """
        json.dump({
            "workdir": self.workdir,
            "nodes": {address: dataclasses.asdict(node) for (address, node) in
                      self.nodes.items()}
        }, file, indent=2)

    def __repr__(self):
        out = f"Workdir: {self.workdir}\n"
        out += "Nodes:\n"
        for (address, node) in self.nodes.items():
            out += f"{address}: {node}\n"
        return out
def is_local(host: str) -> bool:
    """
    Returns true if the given `host` is the local computer.

    `host` is compared against the local hostname (`HOSTNAME`), the loopback names
    "localhost" / "127.0.0.1", and the address that the local hostname resolves to.
    If the local hostname cannot be resolved, `host` is considered remote instead of
    raising an error.
    """
    if host in (HOSTNAME, "localhost", "127.0.0.1"):
        return True
    try:
        return host == socket.gethostbyname(HOSTNAME)
    except OSError:
        # socket.gaierror / socket.herror (both OSError subclasses) when the local
        # hostname has no resolvable address; only the name checks above apply then.
        return False
def start_process(
    commands: List[str],
    host: Optional[str] = None,
    workdir: Optional[str] = None,
    modules: Optional[List[str]] = None,
    name: Optional[str] = None,
    env: Optional[Dict[str, str]] = None,
    pyenv: Optional[str] = None,
    init_cmd: Optional[List[str]] = None
):
    """
    Start a process on the given `host`.

    A bash script is fed to `ssh <host> /bin/bash`, or to a local `setsid /bin/bash` when no
    `host` is given. The script changes into `workdir` and runs the initialization commands.
    It then launches `commands` in the background with stdout/stderr redirected to
    `<workdir>/<name>.out` and `<workdir>/<name>.err`.

    `commands`, `init_cmd` and the values of `env` are inserted into the script verbatim
    (they are not shell-quoted), so any shell syntax in them is interpreted by bash.

    :param commands: List of commands to run (joined with spaces into one command line).
    :param host: Hostname where to start the process. The process runs locally when empty.
    :param workdir: Working directory of the process. Defaults to the current directory.
    :param modules: LMOD modules to load on the host (nothing is loaded when empty).
    :param name: Name (used for stdout/stderr files in the working directory).
        Defaults to "process".
    :param env: Environment variables passed to the process.
    :param pyenv: Absolute path of a Python virtual environment that should be sourced by
        the process.
    :param init_cmd: Initialization commands performed at the start of the process. If any
        of them fails, the script exits before the process is launched.
    :return: Tuple `(pgid, command)`: the process group ID of the background process as
        printed by `ps`, and the bash script that was executed.
    :raises Exception: If the script did not print a process group ID.
    """
    import shlex  # local import keeps this function self-contained

    workdir = os.path.abspath(workdir or os.getcwd())
    name = name or "process"

    # `init_cmd` may be None; always build a new list so the caller's list is not mutated.
    setup = [] if init_cmd is None else list(init_cmd)
    if modules:
        setup.append(f"ml {' '.join(modules)}")
    if pyenv:
        assert os.path.isabs(pyenv)
        setup.append(f"source {shlex.quote(pyenv + '/bin/activate')}")

    args = []
    if env:
        args.append("env")
        args.extend(f"{key}={val}" for (key, val) in env.items())
    args.extend(str(cmd) for cmd in commands)

    output = os.path.join(workdir, name)
    stdout_file = f"{output}.out"
    stderr_file = f"{output}.err"

    # Every init command must succeed, otherwise the script exits with status 1.
    init_line = " && ".join(f"{{ {cmd} || exit 1; }}" for cmd in setup)
    command = "\n".join([
        f"cd {shlex.quote(workdir)} || exit 1",
        init_line,
        "ulimit -c unlimited",
        # Only the output paths are quoted; `args` is deliberately left as shell syntax.
        f"{' '.join(args)} > {shlex.quote(stdout_file)} 2> {shlex.quote(stderr_file)} &",
        # Print the process group ID of the background job; it is returned to the caller.
        "ps -ho pgid $!",
    ])

    launcher = ["ssh", host] if host else ["setsid"]
    launcher.append("/bin/bash")
    process = subprocess.Popen(launcher, cwd=workdir, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE, stdin=subprocess.PIPE)
    out, err = process.communicate(command.encode())

    # The PGID is the last line of the script's stdout; earlier lines (e.g. output of init
    # commands or a `ps` header) are ignored.
    lines = out.strip().splitlines()
    pgid = lines[-1].strip() if lines else b""
    if not pgid.isdigit():
        logging.error(
            f"Process startup failed with status: {process.returncode}, stderr: {err.decode()}, stdout: {out.decode()}")
        if os.path.isfile(stderr_file):
            with open(stderr_file) as f:
                logging.error(f.read())
        raise Exception(f"Process startup failed on {host if host else 'localhost'}: {command}")
    return (int(pgid), command)
def kill_process_pool(args):
    """
    Pool adapter: unpack a `(kill_fn, node, process)` tuple and call `kill_fn(node, process)`.

    `Pool.map` passes one argument per task, hence the single tuple parameter.
    The result of `kill_fn` is discarded.
    """
    fn, address, proc = args
    fn(address, proc)
def kill_process(host: str, pid: int, signal="TERM"):
    """
    Send `signal` to the process group `pid` on `host`.

    The `kill` command is run directly when `host` is the local machine and through SSH
    otherwise. Because `pid` is passed to `kill` negated, it addresses a process group.

    :param host: Hostname where the process is located.
    :param pid: Process group ID of the process to kill.
    :param signal: Signal used to kill the process. One of "TERM", "KILL" or "INT".
    :return: True if `kill` exited successfully, False otherwise (the failure is logged).
    """
    assert signal in ("TERM", "KILL", "INT")
    logging.debug(f"Killing PGID {pid} on {host}")
    kill_cmd = ["kill", f"-{signal}", f"-{pid}"]
    if is_local(host):
        argv = kill_cmd
    else:
        argv = ["ssh", host, "--", *kill_cmd]
    result = subprocess.run(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode == 0:
        return True
    out_text = result.stdout.decode().strip()
    err_text = result.stderr.decode().strip()
    logging.error(
        f"error: {result.returncode} {out_text} {err_text}")
    return False