Commit 38a55386 authored by Stanislav Bohm

Wrapper improved
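
This commit extends the Python wrapper so that wrapped functions can take
arguments: Task arguments become task inputs with a per-argument layout
chosen via the new @quake.arg decorator, while plain values are passed as
pickled constants. It also lets gather() take an optional output_id
(collapsing the single-output case by default), replaces unkeep() with
remove(), which validates the 'keep' flag and the task state, and exposes
the MPI rank to running jobs via quake.job.get_rank().

A minimal usage sketch, mirroring the tests added in this commit (the
'client' object is assumed to come from the existing test fixture):

    import quake.client as quake

    @quake.mpi_task(n_processes=4)
    def my_const4():
        # One value per MPI rank.
        return 12 + quake.job.get_rank()

    @quake.mpi_task(n_processes=4)
    @quake.arg("a", layout="scatter")
    def my_mul4(a, b):
        return a * b

    quake.set_global_client(client)
    f = my_const4()
    g = my_mul4(f, 2)  # 'a' is a task input, 'b' a pickled constant
    assert quake.gather(g) == [24, 26, 28, 30]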

parent 95a908e9
from .client import Client # noqa
from .functions import mpi_task, set_global_client, wait, wait_all # noqa
from quake.client.base.client import Client # noqa
from .functions import mpi_task, arg # noqa
from .functions import set_global_client, wait, wait_all, remove, gather # noqa
from . import job
\ No newline at end of file
@@ -4,7 +4,7 @@ import logging
import abrpc
import uvloop
from .task import Task, TaskState
from quake.client.base.task import TaskState
uvloop.install()
logger = logging.getLogger(__name__)
@@ -29,10 +29,18 @@ class Client:
logger.info("Connecting to server ...")
self.connection = self.loop.run_until_complete(connect())
def unkeep(self, task):
def remove(self, task):
    # Clear the task's 'keep' flag; a task that was already submitted is
    # also unkept on the server so its data can be released.
    logger.debug("Unkeeping id=%s", task.task_id)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.connection.call("unkeep", task.task_id))
if not task.keep:
raise Exception("'keep' flag is not set for task")
if task.state == TaskState.NEW:
pass # Do nothing
elif task.state == TaskState.SUBMITTED:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.connection.call("unkeep", task.task_id))
else:
raise Exception("Invalid task state")
task.keep = False
def _prepare_submit(self, tasks):
for task in tasks:
@@ -59,7 +67,7 @@ class Client:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.connection.call("wait_all", ids))
def gather(self, task, output_id):
def gather(self, task, output_id=None):
logger.debug("Gathering task id=%s", task.task_id)
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.connection.call("gather", task.task_id, output_id))
import enum
from ..common.taskinput import TaskInput
from ..common.layout import Layout
from quake.common.taskinput import TaskInput
from quake.common.layout import Layout
class TaskState(enum.Enum):
NEW = 0
SUBMITTED = 1
REMOVED = 2
class Task:
@@ -32,21 +31,22 @@ class Task:
"keep": self.keep,
}
def output(self, output_ids="all", layout="all_to_all"):
if isinstance(output_ids, int):
output_ids = [output_ids]
elif output_ids == "all":
output_ids = list(range(self.n_outputs))
if layout == "all_to_all":
layout = Layout(0, 0, 0, self.n_workers * len(output_ids))
elif layout == "cycle":
layout = Layout(1, 0, 0, 1)
else:
assert isinstance(layout, Layout)
return TaskInput(self, output_ids, layout)
DEFAULT_ENV = {}
def make_input(task, output_ids="all", layout="all_to_all"):
if isinstance(output_ids, int):
output_ids = [output_ids]
elif output_ids == "all":
output_ids = list(range(task.n_outputs))
if layout == "all_to_all":
layout = Layout(0, 0, 0, task.n_workers * len(output_ids))
elif layout == "scatter":
layout = Layout(1, 0, 0, 1)
else:
assert isinstance(layout, Layout)
return TaskInput(task, output_ids, layout)
DEFAULT_ENV = {"PYTHONPATH": None} # None = use server value if possible
PY_JOB_ARGS = ("python3", "-m", "quake.job", "$TASK_ID", "$RANK", "$DS_PORT")
...
from .plan import Plan
from .wrapper import FunctionWrapper
from .base.plan import Plan
from .wrapper import FunctionWrapper, ArgConfig
import pickle
global_plan = Plan()
global_client = None
@@ -12,6 +15,21 @@ def mpi_task(*, n_processes, n_outputs=1):
return _builder
def arg(name, layout="all_to_all"):
    # Record a layout for one argument of a task function; the configuration
    # is stored on the function (or its FunctionWrapper) until a task is built.
    def _builder(fn):
if isinstance(fn, FunctionWrapper):
configs = fn.arg_configs
elif hasattr(fn, "_quake_args"):
configs = fn._quake_args
else:
configs = {}
fn._quake_args = configs
configs[name] = ArgConfig(layout)
return fn
return _builder
def wait(task):
_flush_global_plan()
global_client.wait(task)
@@ -22,6 +40,21 @@ def wait_all(tasks):
global_client.wait_all(tasks)
def gather(task, output_id=None, collapse_single_output=True):
    _flush_global_plan()
    # For a single-output task (with collapsing enabled), gather just that
    # output; otherwise gather every output as a list of lists.
    if output_id is None and task.n_outputs == 1 and collapse_single_output:
        output_id = 0
    result = global_client.gather(task, output_id)
    if output_id is not None:
        return [pickle.loads(r) for r in result]
    else:
        return [[pickle.loads(c) for c in r] for r in result]
def remove(task):
global_client.remove(task)
def reset_global_plan():
global_plan.take_tasks()
...
_rank = None
def _set_rank(rank):
global _rank
_rank = rank
def get_rank():
return _rank
\ No newline at end of file
from ..job.config import JobConfiguration
from .task import new_py_task
from .base.task import new_py_task, make_input, Task
from .job import _set_rank
import cloudpickle
import pickle
import inspect
import collections
def task_runner(jctx, input_data, python_job):
return python_job.run()
_set_rank(jctx.rank)
return python_job.run(input_data)
def _load(obj):
    # Unpickle gathered input data: bytes hold a single pickled value, a
    # one-element list collapses to its element, longer lists are loaded
    # element-wise.
if isinstance(obj, bytes):
return pickle.loads(obj)
if len(obj) == 1:
return _load(obj[0])
return [_load(o) for o in obj]
class PythonJob:
def __init__(self, pickled_fn):
def __init__(self, pickled_fn, task_args, const_args):
self.pickled_fn = pickled_fn
self.task_args = task_args
self.const_args = const_args
def run(self):
result = cloudpickle.loads(self.pickled_fn)()
def run(self, input_data):
#kwargs = {name: pickle.loads(input_data[value]) for name, value in self.task_args.items()}
kwargs = self.const_args
for name, value in self.task_args.items():
kwargs[name] = _load(input_data[value])
result = cloudpickle.loads(self.pickled_fn)(**kwargs)
return [pickle.dumps(result)]
ArgConfig = collections.namedtuple("ArgConfig", "layout")
class FunctionWrapper:
def __init__(self, fn, n_processes, n_outputs, plan):
self.fn = fn
self.signature = inspect.signature(fn)
self.n_processes = n_processes
self.n_outputs = n_outputs
self.plan = plan
if hasattr(fn, "_quake_args"):
self.arg_configs = fn._quake_args
assert isinstance(self.arg_configs, dict)
delattr(fn, "_quake_args")
else:
self.arg_configs = {}
self._pickled_fn = None
@@ -35,12 +64,30 @@ class FunctionWrapper:
self._pickled_fn = cloudpickle.dumps(self.fn)
return self._pickled_fn
def _prepare_inputs(self, args, kwargs):
    # Split the bound arguments: Task instances become task inputs (with the
    # layout configured via @arg, defaulting to "all_to_all"); all other
    # values are passed through as constants.
binding = self.signature.bind(*args, **kwargs)
inputs = []
task_args = {}
const_args = {}
for name, value in binding.arguments.items():
if isinstance(value, Task):
arg_config = self.arg_configs.get(name)
if arg_config:
layout = arg_config.layout
else:
layout = "all_to_all"
task_args[name] = len(inputs)
inputs.append(make_input(value, layout=layout))
else:
const_args[name] = value
return inputs, task_args, const_args
def __repr__(self):
return "<FunctionWrapper of '{}'>".format(self.fn.__class__.__name__)
def __call__(self, *args, keep=False, **kwargs):
inputs = []
payload = PythonJob(self.pickle_fn())
inputs, task_args, const_args = self._prepare_inputs(args, kwargs)
payload = PythonJob(self.pickle_fn(), task_args, const_args)
config = pickle.dumps(JobConfiguration(task_runner, self.n_outputs, payload))
task = new_py_task(self.n_outputs, self.n_processes, keep, config, inputs)
self.plan.add_task(task)
...
@@ -4,6 +4,8 @@ def compute_b_levels(tasks):
stack = []
to_compute = {}
for task in tasks.values():
if task.consumers is None:
continue
c = len(task.consumers)
to_compute[task] = c
if c == 0:
...
@@ -3,6 +3,7 @@ import json
import logging
import random
import tempfile
import os
import abrpc
import uvloop
@@ -39,15 +40,13 @@ from .worker import Worker
async def _wait_for_task(task):
if not task.keep:
raise Exception("Waiting on non-keep tasks are not allowed (task={})".format(task))
state = task.state
if task.state == TaskState.UNFINISHED:
event = asyncio.Event()
task.add_event(event)
await event.wait()
state = task.state
if state == TaskState.FINISHED:
if state == TaskState.FINISHED or state == TaskState.RELEASED:
return
elif task.state == TaskState.ERROR:
raise Exception(task.error)
@@ -109,16 +108,24 @@ class Server:
self.local_ds_connection = None
self.ds_port = local_ds_port
@staticmethod
async def _gather_output(task, output_id):
    # Pick a random worker among those holding each part of the output and
    # fetch all parts concurrently.
    workers = [random.choice(tuple(ws)) for ws in task.placement[output_id]]
assert len(workers) == task.n_workers
fs = [w.ds_connection.call("get_data", task.make_data_name(output_id, i)) for i, w in enumerate(workers)]
return await asyncio.gather(*fs)
@abrpc.expose()
async def gather(self, task_id, output_id):
task = self.state.tasks.get(task_id)
if task is None:
raise Exception("Task '{}' not found".format(task_id))
await _wait_for_task(task)
workers = [random.choice(tuple(ws)) for ws in task.placement[output_id]]
assert len(workers) == task.n_workers
fs = [w.ds_connection.call("get_data", task.make_data_name(output_id, i)) for i, w in enumerate(workers)]
return await asyncio.gather(*fs)
if output_id is None:
fs = [self._gather_output(task, output_id) for output_id in range(task.n_outputs)]
return await asyncio.gather(*fs)
else:
return await self._gather_output(task, output_id)
@abrpc.expose()
async def wait(self, task_id):
@@ -182,6 +189,10 @@ class Server:
args.append(worker.hostname)
if "env" in task.config:
for name, value in task.config["env"].items():
if value is None:
value = os.environ.get(name)
if value is None:
continue
args.append("-x")
args.append("{}={}".format(name, value))
for arg in config_args:
...
@@ -3,7 +3,6 @@ version: "2"
services:
mpi_head:
build: .
# image: openmpi
environment:
- PYTHONPATH=/app
ports:
@@ -19,7 +18,6 @@ services:
mpi_node:
build: .
# image: openmpi
environment:
- PYTHONPATH=/app
networks:
...
import pytest
from quake.client.task import Task
from quake.client.base.task import Task
# TX[CPUS, Outputs]
@@ -99,7 +96,7 @@ def test_greedy_match():
t1 = Task(1, 1, 3, None, False, [])
t2 = Task(2, 1, 1, None, False, [])
t3 = Task(3, 1, 2, None, False, [])
t4 = Task(4, 0, 3, None, False, [t1.output(0, "cycle"), t2.output(0, "all_to_all"), t3.output(0, "cycle")])
t4 = Task(4, 0, 3, None, False, [t1.output(0, "scatter"), t2.output(0, "all_to_all"), t3.output(0, "scatter")])
workers = make_workers(4)
state = State(workers)
state.add_tasks([t.to_dict() for t in [t1, t2, t3, t4]])
...
import quake.client as quake
import mpi4py
import pytest
@quake.mpi_task(n_processes=1)
def my_function():
def my_const():
return 12
def test_wrapper_simple(client):
f = my_function(10, keep=True)
@quake.mpi_task(n_processes=1)
def my_sum(a, b):
return a + b
@quake.mpi_task(n_processes=1)
def my_sum_c(a, b):
return a + b
@quake.mpi_task(n_processes=4)
def my_const4():
return 12 + quake.job.get_rank()
@quake.mpi_task(n_processes=4)
@quake.arg("a", layout="scatter")
def my_mul4(a, b):
return a * b
def test_wrapper_wait_and_gather(client):
quake.set_global_client(client)
quake.wait(f)
\ No newline at end of file
f = my_const()
quake.wait(f)
quake.wait(f)
with pytest.raises(Exception, match="flag is not set"):
quake.remove(f)
f = my_const(keep=True)
quake.wait(f)
quake.wait(f)
assert quake.gather(f, collapse_single_output=False) == [[12]]
assert quake.gather(f) == [12]
assert quake.gather(f, 0) == [12]
quake.remove(f)
with pytest.raises(Exception, match="flag is not set"):
quake.remove(f)
f = my_const(keep=True)
assert quake.gather(f, collapse_single_output=False) == [[12]]
assert quake.gather(f) == [12]
assert quake.gather(f, 0) == [12]
quake.remove(f)
with pytest.raises(Exception, match="flag is not set"):
quake.remove(f)
f = my_const4(keep=True)
assert quake.gather(f, collapse_single_output=False) == [[12, 13, 14, 15]]
assert quake.gather(f) == [12, 13, 14, 15]
assert quake.gather(f, 0) == [12, 13, 14, 15]
quake.remove(f)
def test_wrapper_args(client):
quake.set_global_client(client)
f = my_const()
g = my_sum(f, f)
h = my_sum(g, my_const())
j = my_sum(h, f)
g = my_sum_c(j, 7)
assert quake.gather(g) == [55]
f = my_const()
g = my_mul4(f, 2)
assert quake.gather(g) == [24] * 4
f = my_const4()
g = my_mul4(f, 2)
assert quake.gather(g) == [24, 26, 28, 30]
\ No newline at end of file