Commit 7355a679 authored by Stanislav Bohm

Server update

parent bcd621bd
from .client import Client # noqa
from ..common.task import Task # noqa
\ No newline at end of file
import logging
import asyncio
import uvloop
import abrpc
from .task import Task, TaskState
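# uvloop: drop-in, libuv-based replacement for the default asyncio event loop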
uvloop.install()
logger = logging.getLogger(__name__)
class Client:
def __init__(self, server):
self.server = server
def submit(self, tasks):
self.server.loop.call_soon_threadsafe(self.server.submit, tasks)
def wait_for_task(self, task):
async def wait_for_task():
logger.debug("Waiting for task %s", task)
if task.state == TaskState.UNFINISHED:
event = asyncio.Event()
task.add_event(event)
await event.wait()
if task.state == TaskState.ERROR:
return task.error
f = asyncio.run_coroutine_threadsafe(wait_for_task(), loop=self.server.loop)
result = f.result()
if result is not None:
raise Exception(result)
def __init__(self, hostname="localhost", port=8600):
self.connection = None
self.unsubmitted_tasks = []
self.loop = asyncio.get_event_loop()
self._connect(hostname, port)
self.id_counter = 0
def _connect(self, hostname, port):
async def connect():
connection = abrpc.Connection(await asyncio.open_connection(hostname, port=port))
asyncio.ensure_future(connection.serve())
logger.info("Connection to server established")
return connection
logger.info("Connecting to server ...")
self.connection = self.loop.run_until_complete(connect())
def new_task(self, n_outputs, n_workers, args, keep=False, inputs=()):
task = Task(self.id_counter, n_outputs, n_workers, args, keep, inputs)
self.id_counter += 1
self.unsubmitted_tasks.append(task)
return task
def submit(self):
logger.debug("Submitting %s tasks", len(self.unsubmmited_tasks))
if not self.unsubmmited_tasks:
return
for task in self.unsubmmited_tasks:
assert task.state == TaskState.NEW
task.state = TaskState.SUBMITTED
tasks = [task.to_dict() for task in self.unsubmmited_tasks]
self.unsubmmited_tasks = []
self.loop.run_until_complete(self.connection.call("submit", tasks))
def wait(self, task):
logger.debug("Waiting on task id=%s", task.task_id)
self.loop.run_until_complete(self.connection.call("wait", task.task_id))
\ No newline at end of file
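Taken together, the new client drives the server synchronously over a single abrpc connection. A minimal usage sketch (editor's illustration, not part of the commit; the "myjob.*" module names are hypothetical placeholders for whatever command the workers should run under mpirun):

from quake.client import Client

client = Client()  # connects to localhost:8600 by default
t1 = client.new_task(n_outputs=1, n_workers=2,
                     args=("python3", "-m", "myjob.stage1"))
t2 = client.new_task(n_outputs=1, n_workers=2,
                     args=("python3", "-m", "myjob.stage2"),
                     inputs=(t1.output(0),))
client.submit()  # ships both tasks in one "submit" RPC
client.wait(t2)  # blocks until t2 (and hence t1) has finished

Depending on an earlier task without keep=True is allowed here because both tasks travel in the same submission.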
import enum
from ..common.taskinput import TaskInput
class TaskState(enum.Enum):
NEW = 0
SUBMITTED = 1
REMOVED = 2
class Task:
def __init__(self, task_id, n_outputs, n_workers, args, keep, inputs):
self.task_id = task_id
self.inputs = inputs
self.n_outputs = n_outputs
self.n_workers = n_workers
self.args = tuple(args)
self.keep = keep
self.state = TaskState.NEW
def to_dict(self):
return {
"task_id": self.task_id,
"inputs": [inp.to_dict() for inp in self.inputs],
"n_outputs": self.n_outputs,
"n_workers": self.n_workers,
"args": self.args,
"keep": self.keep,
}
def output(self, output_id):
return TaskInput(self, output_id)
\ No newline at end of file
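For reference, to_dict flattens each dependency edge into plain ids; a small sketch (editor's illustration, with hypothetical command names) of the wire form for a task with one input:

a = Task(0, n_outputs=1, n_workers=1, args=("cmd_a",), keep=True, inputs=())
b = Task(1, n_outputs=1, n_workers=1, args=("cmd_b",), keep=False,
         inputs=(a.output(0),))
assert b.to_dict() == {
    "task_id": 1,
    "inputs": [{"task": 0, "output_id": 0}],
    "n_outputs": 1,
    "n_workers": 1,
    "args": ("cmd_b",),
    "keep": False,
}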
class TaskInput:
def __init__(self, task: "Task", output_id: int, layout=None):
assert 0 <= output_id < task.n_outputs
self.task = task
self.output_id = output_id
self.layout = layout
def to_dict(self):
return {
"task": self.task.task_id,
"output_id": self.output_id
}
@staticmethod
def from_dict(data, tasks):
return TaskInput(tasks[data["task"]], data["output_id"])
\ No newline at end of file
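from_dict resolves the stored task id back to a live Task through a tasks map, as the server does during submit; a round-trip sketch (editor's illustration, using the client-side Task from above for brevity):

t = Task(0, n_outputs=2, n_workers=1, args=("cmd",), keep=True, inputs=())
tasks = {t.task_id: t}  # id -> Task, as maintained by the server
inp = TaskInput.from_dict(t.output(1).to_dict(), tasks)
assert inp.task is t and inp.output_id == 1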
import aiofiles
# import aiofiles
import asyncio
import os
@@ -18,24 +18,29 @@ class Object:
"size": self.size,
}
"""
async def remove(self):
async with self.lock:
if self.filename:
await aiofiles.os.remove(self.filename)
self.data = None
self.filename = None
async def get_data(self):
if self.data is not None:
return self.data
async with self.lock:
if self.filename:
async with aiofiles.open(self.filename, "rb") as f:
return await f.read()
else:
return None
"""
def get_data(self):
return self.data
"""
if self.data is not None:
return self.data
async with self.lock:
if self.filename:
async with aiofiles.open(self.filename, "rb") as f:
return await f.read()
else:
return None
async def map_to_fs(self, workdir):
async with self.lock:
if self.filename is not None:
@@ -47,4 +52,5 @@ class Object:
async with aiofiles.open(filename, "wb") as f:
await f.write(self.data)
self.data = None
return filename
\ No newline at end of file
return filename
"""
\ No newline at end of file
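The file-backed path that this commit comments out relied on self.lock to serialize readers against remove(); callers then had to treat a None result as "object removed". A condensed sketch of that pattern (editor's illustration, assuming an object holding data, filename, and an asyncio.Lock):

import asyncio
import aiofiles

class FileBackedObject:
    def __init__(self, data):
        self.data = data
        self.filename = None            # set once the data is spilled to disk
        self.lock = asyncio.Lock()

    async def get_data(self):
        if self.data is not None:       # fast path: still in memory
            return self.data
        async with self.lock:           # may be racing with remove()
            if self.filename:
                async with aiofiles.open(self.filename, "rb") as f:
                    return await f.read()
            return None                 # removed while we waited for the lock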
@@ -20,7 +20,7 @@ class Service:
self.stats_obj_fetched = 0
self.stats_obj_data_provided = 0
self.stats_obj_file_provided = 0
#self.stats_obj_file_provided = 0
async def _serve(self, connection, hostname, port):
await connection.serve()
@@ -78,12 +78,13 @@ class Service:
validate_name(name)
obj = await self._get_object(name, hostname, port)
self.stats_obj_data_provided += 1
data = await obj.get_data()
if data is None:
# This can happen in case of racing with .remove()
raise Exception("Object removed")
return data
return obj.get_data()
#if data is None:
# # This can happen in case of racing with .remove()
# raise Exception("Object removed")
#return data
"""
@expose()
async def map_to_fs(self, name, hostname=None, port=None):
validate_name(name)
@@ -94,6 +95,7 @@ class Service:
# This can happen in case of racing with .remove()
raise Exception("Object removed")
return filename
"""
@expose()
async def remove(self, name):
@@ -101,13 +103,13 @@ class Service:
if obj_f is None:
raise Exception("Object not found")
del self.objects[name]
obj = await obj_f
await obj.remove()
#obj = await obj_f
#await obj.remove()
@expose()
async def get_stats(self):
return {
"obj_file_provided": self.stats_obj_file_provided,
#"obj_file_provided": self.stats_obj_file_provided,
"obj_data_provided": self.stats_obj_data_provided,
"obj_fetched": self.stats_obj_fetched,
"connections": len(self.connections),
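Every @expose() method above is callable over abrpc with the same connect-and-serve pattern used elsewhere in this commit; e.g. probing the counters (editor's sketch; the port is assumed to be the data service's default from the new --ds-port option):

import asyncio
import abrpc

async def print_stats(hostname="localhost", port=8602):
    conn = abrpc.Connection(await asyncio.open_connection(hostname, port=port))
    asyncio.ensure_future(conn.serve())   # pump incoming frames
    print(await conn.call("get_stats"))

asyncio.get_event_loop().run_until_complete(print_stats())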
from quake.common.task import Task, TaskInput # noqa
from quake.server.worker import Worker
import argparse
import asyncio
import logging
import os
@@ -7,27 +8,35 @@ from .server import Server
logger = logging.getLogger(__name__)
def get_worker_hostnames():
if "QUAKE_WORKERS" not in os.environ:
raise Exception("Set 'QUAKE_WORKERS' env variable")
return os.environ.get("QUAKE_WORKERS").split(",")
#def get_worker_hostnames():
# if "QUAKE_WORKERS" not in os.environ:
# raise Exception("Set 'QUAKE_WORKERS' env variable")
# return os.environ.get("QUAKE_WORKERS").split(",")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8600)
parser.add_argument("--ds-port", type=int, default=8602)
parser.add_argument("--workers", type=str, default="localhost")
return parser.parse_args()
def main():
logging.basicConfig(level=0)
worker_hostnames = get_worker_hostnames()
server = Server(worker_hostnames)
args = parse_args()
server = Server(args.workers.split(","), args.ds_port)
async def handle(conn):
logger.info("New connection %s", conn)
logger.info("New client connection %s", conn)
await conn.serve(server)
logger.info("Connection %s closed", conn)
logger.info("Client connection %s closed", conn)
loop = asyncio.get_event_loop()
loop.run_until_complete(server.connect_to_workers())
loop.run_until_complete(server.connect_to_ds())
loop.run_until_complete(
asyncio.start_server(on_connection(handle), port=8600))
asyncio.start_server(on_connection(handle), port=args.port))
loop.run_forever()
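With the argparse change above, the worker list moves from the QUAKE_WORKERS environment variable to a command-line flag; the server is now started along the lines of (hostnames hypothetical):

python3 -m quake.server --workers=node1,node2 --port 8600 --ds-port 8602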
import asyncio
import logging
import os
import tempfile
import threading
import abrpc
import uvloop
from quake.common.task import TaskState
from .task import TaskState, Task
from ..common.taskinput import TaskInput
# NOTE: installs uvloop as the global event loop policy at import time
uvloop.install()
@@ -36,7 +36,7 @@ from .worker import Worker
class Server:
def __init__(self, worker_hostnames, run_prefix=(), run_cwd=None):
def __init__(self, worker_hostnames, local_ds_port):
logger.debug("Starting QUake server")
workers = []
@@ -46,69 +46,91 @@ class Server:
logger.info("Registering worker worker_id=%s host=%s", i, worker.hostname)
workers.append(worker)
self.id_counter = 0
# self.id_counter = 0
self.tasks = {}
self.ready_tasks = []
self.all_workers = workers
self.free_workers = list(workers)
self.processes = {}
self.run_prefix = tuple(run_prefix)
self.run_cwd = run_cwd
"""
def start(self):
assert self.loop is None
self.loop = asyncio.new_event_loop()
self.stop_event = asyncio.Event(loop=self.loop)
thread = threading.Thread(target=server_thread_main, args=(self,), daemon=True)
thread.start()
def stop(self):
logging.debug("Stopping server")
async def _helper():
self.stop_event.set()
future = asyncio.run_coroutine_threadsafe(_helper(), self.loop)
future.result()
"""
self.processes = {}
#self.run_prefix = tuple(run_prefix)
#self.run_cwd = run_cwd
self.ds_connections = {}
self.local_ds_connection = None
self.ds_port = local_ds_port
@abrpc.expose()
async def wait(self, task_id):
task = self.tasks.get(task_id)
if task is None:
raise Exception("Task '{}' not found".format(task_id))
state = task.state
if task.state == TaskState.UNFINISHED:
event = asyncio.Event()
task.add_event(event)
await event.wait()
state = task.state
if state == TaskState.FINISHED:
return
elif state == TaskState.ERROR:
raise Exception(task.error)
else:
assert 0
@abrpc.expose()
async def submit(self, tasks):
new_ready_tasks = False
new_tasks = set()
def new_id(self):
self.id_counter += 1
return self.id_counter
task_map = self.tasks
for tdict in tasks:
task_id = tdict["task_id"]
if task_id in task_map:
raise Exception("Task id ({}) already used".format(task_id))
def submit(self, tasks):
new_ready_tasks = False
new_tasks = set(tasks)
for task in tasks:
assert task.task_id is None
task_id = self.new_id()
task.task_id = task_id
for tdict in tasks:
task_id = tdict["task_id"]
task = Task(task_id, tdict["n_outputs"], tdict["n_workers"], tdict["args"], tdict["keep"])
logger.debug("Task %s submitted", task_id)
self.tasks[task_id] = task
task_map[task_id] = task
new_tasks.add(task)
tdict["_task"] = task
for tdict in tasks:
task = tdict["_task"]
unfinished_deps = 0
for t in task.deps:
inputs = [TaskInput.from_dict(data, task_map) for data in tdict["inputs"]]
deps = frozenset(inp.task for inp in inputs)
for t in deps:
assert t.state != TaskState.RELEASED
assert t.keep or t in new_tasks, "Dependency on a non-kept task"
t.consumers.add(task)
if not t.is_ready():
if t.state != TaskState.FINISHED:
unfinished_deps += 1
task.inputs = inputs
task.deps = deps
task.unfinished_deps = unfinished_deps
if not unfinished_deps:
new_ready_tasks = True
logger.debug("Task %s is ready", task_id)
logger.debug("Task %s is ready", task)
self.ready_tasks.append(task)
if new_ready_tasks:
self.schedule()
def schedule(self):
logger.debug("Scheduling ...")
for task in self.ready_tasks:
logger.debug("Scheduling ... top_3_tasks: %s", self.ready_tasks[:3])
for task in self.ready_tasks[:]:
if task.n_workers <= len(self.free_workers):
workers = self.free_workers[:task.n_workers]
del self.free_workers[:task.n_workers]
self.ready_tasks.remove(task)
self._start_task(task, workers)
logger.debug("End of scheduling")
def _start_task(self, task, workers):
logger.debug("Starting task %s on %s", task, workers)
@@ -116,18 +138,19 @@ class Server:
assert task not in self.processes
hostnames = ",".join(worker.hostname for worker in workers)
command = self.run_prefix
command = () # self.run_prefix
command += ("mpirun", "--host", hostnames, "--np", str(task.n_workers), "--map-by", "node")
command += task.args
asyncio.ensure_future(self._exec(task, command))
asyncio.ensure_future(self._exec(task, command, workers))
async def _exec(self, task, args):
print("ARGS", args)
async def _exec(self, task, args, workers):
with tempfile.TemporaryFile() as stdout_file:
with tempfile.TemporaryFile() as stderr_file:
process = await asyncio.create_subprocess_exec(
*args, cwd=self.run_cwd, loop=self.loop, stderr=stderr_file, stdout=stdout_file, stdin=asyncio.subprocess.DEVNULL)
*args, stderr=stderr_file, stdout=stdout_file, stdin=asyncio.subprocess.DEVNULL)
exitcode = await process.wait()
new_ready_tasks = False
self.free_workers.extend(workers)
if exitcode != 0:
logger.debug("Task %s FAILED", task)
stderr_file.seek(0)
@@ -144,8 +167,23 @@ class Server:
logger.debug("Task %s finished", task)
task.state = TaskState.FINISHED
task.fire_events()
async def connect_to_workers(self):
fs = [asyncio.open_connection(w.hostname, port=8500) for w in self.all_workers]
for t in task.consumers:
t.unfinished_deps -= 1
if t.unfinished_deps <= 0:
assert t.unfinished_deps == 0
logger.debug("Task %s is ready", t)
self.ready_tasks.append(t)
new_ready_tasks = True
if new_ready_tasks:
self.schedule()
async def connect_to_ds(self):
async def connect(hostname, port):
connection = abrpc.Connection(await asyncio.open_connection(hostname, port=port))
asyncio.ensure_future(connection.serve())
return connection
fs = [connect(w.hostname, self.ds_port) for w in self.all_workers]
connections = await asyncio.gather(*fs)
logger.debug("Data service connections established: %s", connections)
self.ds_connections = dict(zip(self.all_workers, connections))
self.local_ds_connection = connections[0]
\ No newline at end of file
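The consumers/unfinished_deps bookkeeping above is a plain reference count over the task graph: finishing a task decrements each consumer's counter, and a counter hitting zero makes that consumer ready. A stripped-down illustration of just that mechanism (editor's sketch, not the server code):

class Node:
    def __init__(self, deps):
        self.unfinished_deps = len(deps)
        self.consumers = []
        for d in deps:
            d.consumers.append(self)

ready = []

def finish(node):
    for c in node.consumers:
        c.unfinished_deps -= 1
        if c.unfinished_deps == 0:   # all inputs finished -> schedulable
            ready.append(c)

a = Node([])
b = Node([a])
c = Node([a, b])
finish(a)    # b has no unfinished deps left
assert ready == [b]
finish(b)    # now c is ready too
assert ready == [b, c]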
@@ -5,20 +5,11 @@ class TaskState:
ERROR = 4
class TaskInput:
def __init__(self, task: "Task", output_id: int, layout=None):
assert 0 <= output_id < task.n_outputs
self.task = task
self.output_id = output_id
self.layout = layout
class Task:
def __init__(self, n_outputs, n_workers, args, inputs=(), keep=False):
self.task_id = None
self.inputs = inputs
def __init__(self, task_id, n_outputs, n_workers, args, keep):
self.task_id = task_id
self.inputs = []
self.n_outputs = n_outputs
self.n_workers = n_workers
self.args = tuple(args)
@@ -26,7 +17,7 @@ class Task:
self.state = TaskState.UNFINISHED
self.keep = keep
self.deps = frozenset(inp.task for inp in inputs)
self.deps = None
self.unfinished_deps = None
self.consumers = set()
self.events = None
@@ -36,6 +36,7 @@ def docker_cluster():
cmd_prefix = ["docker-compose", "exec", "-T", "--user", "mpirun", "--privileged"]
def make_cmds(cmd):
result = [
@@ -63,19 +64,29 @@ def wait_for_port(port):
@pytest.fixture(scope="function")
def client(docker_cluster):
ps = []
for cmd in make_cmds(["/bin/bash", "-c", "kill `cat /tmp/datasrv` ; sleep 0.1 ; rm -rf /tmp/data ; (python3 -m quake.datasrv /tmp/data & echo $! > /tmp/datasrv)"]):
p = subprocess.Popen(cmd, cwd=DOCKER_DIR)
for cmd in make_cmds(["/bin/bash", "-c", "pgrep python3 | xargs kill; sleep 0.1 ; rm -rf /tmp/data ; python3 -m quake.datasrv /tmp/data"]):
p = subprocess.Popen(cmd, cwd=DOCKER_DIR, stdin=subprocess.DEVNULL)
ps.append(p)
time.sleep(1.5)
hostnames = ",".join(docker_cluster)
#cmd = cmd_prefix + ["mpi_head", "/bin/bash", "-c", "kill `pgrep -f quake.server` ; sleep 0.1; echo 'xxx'; python3 -m quake.server --workers={}".format(hostnames)]
cmd = cmd_prefix + ["mpi_head", "/bin/bash", "-c", "python3 -m quake.server --workers={}".format(hostnames)]
print(" ".join(cmd))
p = subprocess.Popen(cmd, cwd=DOCKER_DIR, stdin=subprocess.DEVNULL)
ps.append(p)
time.sleep(2)
client = Client(port=7600)
# mapped in docker-compose.yml
#wait_for_port(7602)
#wait_for_port(7603)
#wait_for_port(7604)
#wait_for_port(7605)
time.sleep(2)
yield None
yield client
print("Clean up")
for p in ps:
@@ -83,6 +94,7 @@ def client(docker_cluster):
time.sleep(0.1)
for p in ps:
p.terminate()
p.wait()
@pytest.fixture()
@@ -67,4 +67,5 @@ ADD default-mca-params.conf ${HOME}/.openmpi/mca-params.conf
RUN chown -R ${USER}:${USER} ${HOME}/.openmpi
EXPOSE 22
EXPOSE 8600
CMD ["/usr/sbin/sshd", "-D"]
@@ -7,7 +7,8 @@ services:
environment:
- PYTHONPATH=/app
ports:
- 7602:8602
- "7602:8602"
- "7600:8600"
links:
- mpi_node
networks:
@@ -29,4 +30,4 @@ services:
networks:
net:
driver: bridge
#driver: bridge
@@ -29,10 +29,10 @@ def test_data_service(tmpdir, root_dir):
c2 = connection2.call("get_data", "x_1", "localhost", PORT1)
assert [b"123", b"123"] == await asyncio.gather(c1, c2)
path1 = await connection2.call("map_to_fs", "x_1")
assert isinstance(path1, str)
path2 = await connection2.call("map_to_fs", "x_1")