Commit 95a908e9 authored by Stanislav Bohm

First version of wrapper

parent eddf00b0
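
For orientation, a minimal end-to-end sketch of the wrapper API introduced here, assembled from quake/client/functions.py and the test at the bottom of this diff; the hostname and port are the defaults from Client.__init__, and a running quake server is assumed:

import quake.client as quake
from quake.client import Client

# Decorated functions do not run when called; each call records a task
# in the global plan, which is flushed to the server on the first wait.
@quake.mpi_task(n_processes=1)
def my_function():
    return 12

client = Client("localhost", 8600)  # default endpoint of Client.__init__
quake.set_global_client(client)

f = my_function(keep=True)  # builds a Task, nothing executes yet
quake.wait(f)               # submits the plan, then blocks until the task finishes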
from .client import Client # noqa
from .functions import mpi_task, set_global_client, wait, wait_all # noqa
......@@ -11,12 +11,10 @@ logger = logging.getLogger(__name__)
class Client:
PY_JOB_ARGS = ("python3", "-m", "quake.job", "$TASK_ID", "$RANK", "$DS_PORT")
DEFAULT_ENV = {}
def __init__(self, hostname="localhost", port=8600):
self.connection = None
self.unsubmitted_tasks = []
self.loop = asyncio.get_event_loop()
self.id_counter = 0
self._connect(hostname, port)
......@@ -31,60 +29,36 @@ class Client:
logger.info("Connecting to server ...")
self.connection = self.loop.run_until_complete(connect())
def new_task(self, n_outputs, n_workers, config, keep=False, inputs=()):
task = Task(self.id_counter, n_outputs, n_workers, config, keep, inputs)
self.id_counter += 1
self.unsubmitted_tasks.append(task)
return task
def new_mpirun_task(self, n_outputs, n_workers, args, keep=False, task_data=None, inputs=()):
config = {
"type": "mpirun",
"args": args,
"env": self.DEFAULT_ENV
}
if task_data is not None:
assert isinstance(task_data, bytes)
config["data"] = task_data
return self.new_task(n_outputs, n_workers, config, keep, inputs)
def new_py_task(self, n_outputs, n_workers, keep=False, task_data=None, inputs=()):
return self.new_mpirun_task(n_outputs, n_workers, self.PY_JOB_ARGS, keep, task_data, inputs)
def upload_data(self, data, keep=False):
assert isinstance(data, list)
for d in data:
assert isinstance(d, bytes)
config = {
"type": "upload",
"data": data,
}
return self.new_task(1, len(data), config, keep, ())
def unkeep(self, task):
logger.debug("Unkeeping id=%s", task.task_id)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.connection.call("unkeep", task.task_id))
def _prepare_submit(self):
for task in self.unsubmitted_tasks:
def _prepare_submit(self, tasks):
for task in tasks:
assert task.state == TaskState.NEW
task.state = TaskState.SUBMITTED
tasks = [task.to_dict() for task in self.unsubmitted_tasks]
self.unsubmitted_tasks = []
return tasks
task.task_id = self.id_counter
self.id_counter += 1
return [task.to_dict() for task in tasks]
def submit(self):
logger.debug("Submitting %s tasks", len(self.unsubmitted_tasks))
tasks = self._prepare_submit()
if tasks:
self.loop.run_until_complete(self.connection.call("submit", tasks))
def submit(self, tasks):
logger.debug("Submitting %s tasks", len(tasks))
serialized_tasks = self._prepare_submit(tasks)
if serialized_tasks:
self.loop.run_until_complete(self.connection.call("submit", serialized_tasks))
def wait(self, task):
logger.debug("Waiting on task id=%s", task.task_id)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.connection.call("wait", task.task_id))
def wait_all(self, tasks):
ids = [task.task_id for task in tasks]
logger.debug("Waiting on tasks %s", ids)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.connection.call("wait_all", ids))
def gather(self, task, output_id):
logger.debug("Gathering task id=%s", task.task_id)
loop = asyncio.get_event_loop()
......
from .plan import Plan
from .wrapper import FunctionWrapper

global_plan = Plan()
global_client = None

def mpi_task(*, n_processes, n_outputs=1):
    def _builder(fn):
        return FunctionWrapper(fn, n_processes, n_outputs, global_plan)
    return _builder

def wait(task):
    _flush_global_plan()
    global_client.wait(task)

def wait_all(tasks):
    _flush_global_plan()
    global_client.wait_all(tasks)

def reset_global_plan():
    global_plan.take_tasks()

def set_global_client(client):
    global global_client
    global_client = client

def _flush_global_plan():
    tasks = global_plan.take_tasks()
    if tasks:
        global_client.submit(tasks)
\ No newline at end of file
class Plan:
    def __init__(self):
        self.tasks = []

    def add_task(self, task):
        assert task.task_id is None
        self.tasks.append(task)

    def take_tasks(self):
        tasks = self.tasks
        self.tasks = []
        return tasks
\ No newline at end of file
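
Plan is a plain accumulator behind the global plan in functions.py: FunctionWrapper.__call__ appends via add_task(), and _flush_global_plan() drains it with take_tasks() right before a wait. A tiny sketch of that collect-then-flush contract; FakeTask is an illustrative stand-in, since Plan only inspects task_id:

class FakeTask:
    task_id = None  # add_task() asserts the task has not been submitted yet

plan = Plan()
plan.add_task(FakeTask())
plan.add_task(FakeTask())

tasks = plan.take_tasks()       # drains the accumulated tasks...
assert len(tasks) == 2
assert plan.take_tasks() == []  # ...and leaves the plan empty for the next round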
......@@ -44,3 +44,35 @@ class Task:
        else:
            assert isinstance(layout, Layout)
            return TaskInput(self, output_ids, layout)

DEFAULT_ENV = {}
PY_JOB_ARGS = ("python3", "-m", "quake.job", "$TASK_ID", "$RANK", "$DS_PORT")

def new_mpirun_task(n_outputs, n_workers, args, keep=False, task_data=None, inputs=()):
    config = {
        "type": "mpirun",
        "args": args,
        "env": DEFAULT_ENV
    }
    if task_data is not None:
        assert isinstance(task_data, bytes)
        config["data"] = task_data
    return Task(None, n_outputs, n_workers, config, keep, inputs)

def new_py_task(n_outputs, n_workers, keep=False, task_data=None, inputs=()):
    return new_mpirun_task(n_outputs, n_workers, PY_JOB_ARGS, keep, task_data, inputs)

def upload_data(data, keep=False):
    assert isinstance(data, list)
    for d in data:
        assert isinstance(d, bytes)
    config = {
        "type": "upload",
        "data": data,
    }
    return Task(None, 1, len(data), config, keep, ())
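
These module-level helpers supersede the Client methods removed above: tasks are now built with task_id=None and numbered only inside Client._prepare_submit(). A quick illustration of what new_py_task() yields (keyword names follow the signatures above; the config layout is the one built by new_mpirun_task):

t = new_py_task(n_outputs=1, n_workers=2, task_data=b"pickled-job-config")
assert t.task_id is None  # assigned later, at submit time
# Internally the task carries an mpirun config along the lines of:
# {"type": "mpirun", "args": PY_JOB_ARGS, "env": DEFAULT_ENV, "data": b"..."}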
from ..job.config import JobConfiguration
from .task import new_py_task

import cloudpickle
import pickle

def task_runner(jctx, input_data, python_job):
    return python_job.run()

class PythonJob:
    def __init__(self, pickled_fn):
        self.pickled_fn = pickled_fn

    def run(self):
        result = cloudpickle.loads(self.pickled_fn)()
        return [pickle.dumps(result)]
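
PythonJob threads the user function through two serialization layers: cloudpickle for the function itself (unlike the stdlib pickle, it can ship lambdas and closures by value) and plain pickle for the returned data. A self-contained round-trip sketch of that split:

import pickle
import cloudpickle

base = 10
fn = lambda: base + 2               # stdlib pickle.dumps(fn) would raise here

blob = cloudpickle.dumps(fn)        # the function travels by value
result = cloudpickle.loads(blob)()  # "worker side": rebuild and call it
payload = pickle.dumps(result)      # plain data, stdlib pickle is enough

assert pickle.loads(payload) == 12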
class FunctionWrapper:
    def __init__(self, fn, n_processes, n_outputs, plan):
        self.fn = fn
        self.n_processes = n_processes
        self.n_outputs = n_outputs
        self.plan = plan
        self._pickled_fn = None

    def pickle_fn(self):
        if self._pickled_fn:
            return self._pickled_fn
        else:
            self._pickled_fn = cloudpickle.dumps(self.fn)
            return self._pickled_fn

    def __repr__(self):
        return "<FunctionWrapper of '{}'>".format(self.fn.__name__)

    def __call__(self, *args, keep=False, **kwargs):
        inputs = []
        payload = PythonJob(self.pickle_fn())
        config = pickle.dumps(JobConfiguration(task_runner, self.n_outputs, payload))
        task = new_py_task(self.n_outputs, self.n_processes, keep, config, inputs)
        self.plan.add_task(task)
        return task
\ No newline at end of file
class JobConfiguration:
    def __init__(self, fn, n_outputs):
    def __init__(self, fn, n_outputs, payload=None):
        self.fn = fn
        self.n_outputs = n_outputs
        self.payload = payload
\ No newline at end of file
......@@ -90,7 +90,7 @@ class Job:
        input_data = await asyncio.gather(*fs)
        jctx = JobContext(rank, input_data)
        output = config.fn(jctx, input_data)
        output = config.fn(jctx, input_data, config.payload)
        assert len(output) == config.n_outputs
        for i, data in enumerate(output):
......
......@@ -127,6 +127,16 @@ class Server:
raise Exception("Task '{}' not found".format(task_id))
await _wait_for_task(task)
@abrpc.expose()
async def wait_all(self, task_ids):
fs = []
for task_id in task_ids:
task = self.state.tasks.get(task_id)
if task is None:
raise Exception("Task '{}' not found".format(task_id))
fs.append(_wait_for_task(task))
await asyncio.wait(fs)
@abrpc.expose()
async def unkeep(self, task_id):
tasks_to_remove = self.state.unkeep(task_id)
......
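
wait_all fans out one _wait_for_task coroutine per requested task and awaits the whole batch. A standalone sketch of the same fan-out with asyncio.wait; the names are illustrative, not quake's (note that newer Python versions require wrapping coroutines in Tasks before passing them to asyncio.wait):

import asyncio

async def wait_for(name, delay):
    await asyncio.sleep(delay)  # stands in for a task reaching its finished state
    return name

async def main():
    fs = [asyncio.create_task(wait_for("a", 0.01)),
          asyncio.create_task(wait_for("b", 0.02))]
    if fs:  # asyncio.wait() rejects an empty collection
        done, pending = await asyncio.wait(fs)  # waits for ALL of them by default
        assert not pending

asyncio.run(main())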
......@@ -16,6 +16,7 @@ ROOT_DIR = os.path.dirname(TESTS_DIR)
sys.path.insert(0, ROOT_DIR)
from quake.client import Client # noqa
from quake.client.functions import reset_global_plan, set_global_client # noqa
@pytest.fixture(scope="session")
......@@ -85,8 +86,14 @@ def client(docker_cluster):
    # wait_for_port(7604)
    # wait_for_port(7605)
    reset_global_plan()
    set_global_client(None)
    yield client
    reset_global_plan()
    set_global_client(None)
    print("Clean up")
    for p in ps:
        p.kill()
......
import quake.client as quake

import mpi4py  # noqa

@quake.mpi_task(n_processes=1)
def my_function():
    return 12

def test_wrapper_simple(client):
    f = my_function(10, keep=True)
    quake.set_global_client(client)
    quake.wait(f)
\ No newline at end of file