Commit aa5a2203 authored by Stanislav Bohm

Reformatting

parent 02d8febd
 from quake.client.base.client import Client # noqa

 from . import job
-from .functions import mpi_task, arg # noqa
-from .functions import set_global_client, wait, wait_all, remove, gather # noqa
+from .functions import ( # noqa
+    arg,
+    gather,
+    mpi_task,
+    remove,
+    set_global_client,
+    wait,
+    wait_all,
+)
@@ -31,7 +31,9 @@ class Client:
                 break
             except ConnectionError as e:
                 i += 1
-                logger.error("Could not connect to server (attempt [%s,%s])", i, max_tries)
+                logger.error(
+                    "Could not connect to server (attempt [%s,%s])", i, max_tries
+                )
                 if i == max_tries:
                     raise e
                 await asyncio.sleep(1.0)
......
-import pickle
 import os
+import pickle

-from .base.plan import Plan
 from . import Client
-from .wrapper import FunctionWrapper, ArgConfig, ResultProxy
+from .base.plan import Plan
+from .wrapper import ArgConfig, FunctionWrapper, ResultProxy

 global_plan = Plan()
 global_client = None
@@ -11,7 +11,7 @@ global_client = None

 # ===== DECORATORS =========================
-def mpi_task(*, n_processes, n_outputs=1):
+def mpi_task(*, n_processes, n_outputs=None):
     def _builder(fn):
         return FunctionWrapper(fn, n_processes, n_outputs, global_plan)
......
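The hunk above changes the default of mpi_task from n_outputs=1 to n_outputs=None. Under the new semantics (see PythonJob.run below), None means the function's return value is kept whole as a single output, while an integer asks the runtime to split a returned list or tuple into that many outputs. A minimal usage sketch, assuming a running quake cluster with a client set up as in the tests below (not runnable standalone):

    import quake.client as quake

    @quake.mpi_task(n_processes=1)  # n_outputs=None: the result stays one output
    def single():
        return "whole result"

    @quake.mpi_task(n_processes=1, n_outputs=2)  # result must be a 2-element list/tuple
    def pair():
        return ["first", "second"]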
@@ -4,9 +4,9 @@ import pickle

 import cloudpickle

-from .base.task import new_py_task, make_input, Task
-from .job import _set_rank
 from ..job.config import JobConfiguration
+from .base.task import Task, make_input, new_py_task
+from .job import _set_rank


 def task_runner(jctx, input_data, python_job):
@@ -23,10 +23,11 @@ def _load(obj):

 class PythonJob:
-    def __init__(self, pickled_fn, task_args, const_args):
+    def __init__(self, pickled_fn, task_args, const_args, n_outputs):
         self.pickled_fn = pickled_fn
         self.task_args = task_args
         self.const_args = const_args
+        self.n_outputs = n_outputs

     def run(self, input_data):
         # kwargs = {name: pickle.loads(input_data[value]) for name, value in self.task_args.items()}
@@ -34,7 +35,22 @@ class PythonJob:
         for name, value in self.task_args.items():
             kwargs[name] = _load(input_data[value])
         result = cloudpickle.loads(self.pickled_fn)(**kwargs)
-        return [pickle.dumps(result)]
+        if self.n_outputs is None:
+            return [pickle.dumps(result)]
+        else:
+            if not isinstance(result, (list, tuple)):
+                raise Exception(
+                    "Multiple outputs were specified; the result of the Python call"
+                    " has to be a list or tuple, not {}".format(type(result))
+                )
+            if self.n_outputs != len(result):
+                raise Exception(
+                    "Invalid number of outputs produced: the function returned {},"
+                    " but {} were expected".format(len(result), self.n_outputs)
+                )
+            return [pickle.dumps(r) for r in result]


 ArgConfig = collections.namedtuple("ArgConfig", "layout")
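The validation added to PythonJob.run fixes the serialization contract: with n_outputs=None the whole result is pickled as one output object, while with an integer the result must be a list or tuple of exactly that length and is pickled element-wise. A standalone mimic of that logic (pack_outputs is a hypothetical name, not part of quake; runnable as-is):

    import pickle

    def pack_outputs(result, n_outputs):
        # None: keep the whole result as a single pickled output object
        if n_outputs is None:
            return [pickle.dumps(result)]
        if not isinstance(result, (list, tuple)):
            raise Exception("expected a list or tuple, got {}".format(type(result)))
        if len(result) != n_outputs:
            raise Exception("got {} outputs, expected {}".format(len(result), n_outputs))
        return [pickle.dumps(r) for r in result]

    assert len(pack_outputs(("a", "b"), 2)) == 2
    assert len(pack_outputs(("a", "b"), None)) == 1  # the tuple stays one object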
@@ -87,11 +103,17 @@ class FunctionWrapper:
     def __repr__(self):
         return "<FunctionWrapper of '{}'>".format(self.fn.__class__.__name__)

-    def __call__(self, *args, keep=False, **kwargs):
+    def __call__(self, *args, keep=False, n_outputs=None, **kwargs):
         inputs, task_args, const_args = self._prepare_inputs(args, kwargs)
-        payload = PythonJob(self.pickle_fn(), task_args, const_args)
-        config = pickle.dumps(JobConfiguration(task_runner, self.n_outputs, payload))
-        task = new_py_task(self.n_outputs, self.n_processes, keep, config, inputs)
+        if n_outputs is None:
+            n_outputs = self.n_outputs
+        if n_outputs is None:
+            real_n_outputs = 1
+        else:
+            real_n_outputs = n_outputs
+        payload = PythonJob(self.pickle_fn(), task_args, const_args, n_outputs)
+        config = pickle.dumps(JobConfiguration(task_runner, real_n_outputs, payload))
+        task = new_py_task(real_n_outputs, self.n_processes, keep, config, inputs)
         self.plan.add_task(task)
         return ResultProxy(task)
......
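Note the precedence introduced in __call__ above: a call-site n_outputs overrides the decorator's value, and None still maps to one physical output (real_n_outputs = 1). A hedged sketch of the per-call override, assuming a client is set up as in the tests below:

    @quake.mpi_task(n_processes=1)  # no n_outputs at decoration time
    def flexible():
        return ["a", "b", "c"]

    r = flexible(n_outputs=3)  # per-call override: the 3-element result becomes 3 outputs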
 import asyncio
 from datetime import datetime

 import uvloop
-from abrpc import expose, Connection
+from abrpc import Connection, expose

-from .obj import Object
 from .monitoring import get_resources
+from .obj import Object

 uvloop.install()
......
@@ -5,12 +5,12 @@ import logging

 # import cloudpickle
 import pickle
 import random
+from datetime import datetime

 import abrpc

 from quake.common.layout import Layout
 from quake.common.utils import make_data_name
-from datetime import datetime

 logger = logging.getLogger(__name__)
@@ -109,7 +109,12 @@ class Job:
         input_data = await asyncio.gather(*fs)
         jctx = JobContext(rank, input_data)
         output = config.fn(jctx, input_data, config.payload)
-        assert len(output) == config.n_outputs
+        if len(output) != config.n_outputs:
+            raise Exception(
+                "Task produced {} outputs, but {} were expected".format(
+                    len(output), config.n_outputs
+                )
+            )
         for i, data in enumerate(output):
             await self.upload_data(i, data)
......
 import json
 from datetime import datetime

-from pandas import DataFrame
 import numpy as np
+from pandas import DataFrame


 class EventStream:
......
import logging
from .task import TaskState
logger = logging.getLogger(__file__)
......
@@ -7,10 +7,10 @@ import tempfile

 import abrpc
 import uvloop
+from aiofile import AIOFile, Writer

 from .state import State
 from .task import TaskState
-from aiofile import AIOFile, Writer

 # !!!!!!!!!!!!!!!
 uvloop.install()
@@ -348,7 +348,13 @@ class Server:
                 return connection
             except ConnectionError as e:
                 error = e
-                logger.error("Failed to connect to %s:%s (attempt %s/%s)", hostname, port, i + 1, RETRY_COUNT)
+                logger.error(
+                    "Failed to connect to %s:%s (attempt %s/%s)",
+                    hostname,
+                    port,
+                    i + 1,
+                    RETRY_COUNT,
+                )
                 await asyncio.sleep(1.0)
         raise error
......
 import logging
 import sys

+from ..common.taskinput import TaskInput
 from .scheduler import compute_b_levels
 from .task import Task, TaskState
-from ..common.taskinput import TaskInput

 logger = logging.getLogger(__file__)
......
@@ -18,7 +18,6 @@ sys.path.insert(0, ROOT_DIR)

 from quake.client import Client # noqa
 from quake.client.functions import reset_global_plan, set_global_client # noqa

-
 nodes = 3
@@ -91,12 +90,14 @@ def client(docker_cluster, tmpdir):
     hostnames = ",".join(docker_cluster)
     # cmd = cmd_prefix + ["mpi_head", "/bin/bash", "-c", "kill `pgrep -f quake.server` ; sleep 0.1; echo 'xxx'; python3 -m quake.server --workers={}".format(hostnames)]
-    monitoring_file = str(logdir.join("monitoring"))
+    # monitoring_file = str(logdir.join("monitoring"))
     cmd = cmd_prefix + [
         "mpihead",
         "/bin/bash",
         "-c",
-        "python3 -m quake.server --debug --workers={} --monitoring=/tmp/monitoring".format(hostnames),
+        "python3 -m quake.server --debug --workers={} --monitoring=/tmp/monitoring".format(
+            hostnames
+        ),
     ]
     # print(" ".join(cmd))
     popen_helper(cmd, logfile=logdir.join("server"))
......
-import quake.client as quake
-import pytest
-import subprocess
-from conftest import DOCKER_DIR
 import json
+import subprocess
 import time
+
+import pytest
+from conftest import DOCKER_DIR
+
+import quake.client as quake


 @quake.mpi_task(n_processes=2)
 def my_sleep():
@@ -24,20 +26,27 @@ def monitor_client(client):
     s3 = my_sleep2(s)
     quake.wait_all([s2, s3])

     import sys

     time.sleep(2)
     sys.stderr.write("Getting /tmp/monitoring\n")
-    output = subprocess.check_output(["docker-compose", "exec", "-T", "mpihead", "cat", "/tmp/monitoring"], cwd=DOCKER_DIR)
-    #print(output)
-    #with open("/tmp/x", "wb") as f:
+    output = subprocess.check_output(
+        ["docker-compose", "exec", "-T", "mpihead", "cat", "/tmp/monitoring"],
+        cwd=DOCKER_DIR,
+    )
+    # print(output)
+    # with open("/tmp/x", "wb") as f:
     #    f.write(output)
     lines = output.decode().split("\n")[:-1]
     assert len(lines) >= 12
     data = [json.loads(line) for line in lines]
     hostnames = set()
     for value in data:
-        assert set(value.keys()).issubset({"timestamp", "resources", "service", "hostname", "events"})
+        assert set(value.keys()).issubset(
+            {"timestamp", "resources", "service", "hostname", "events"}
+        )
         hostnames.add(value["hostname"])
     assert hostnames == {"mpihead", "mpinode1", "mpinode2", "mpinode3"}


 def test_monitoring(monitor_client):
-    pass
\ No newline at end of file
+    pass
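The assertions above pin down the monitoring file format: one JSON object per line, each with keys drawn from {timestamp, resources, service, hostname, events}. A minimal standalone reader for such a file (the path is the one used in the test; outside the container it would have to be copied out first, as the test does with docker-compose exec):

    import json

    with open("/tmp/monitoring") as f:
        records = [json.loads(line) for line in f if line.strip()]

    hostnames = {record["hostname"] for record in records}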
from quake.client.base.task import Task, make_input
# TX[CPUS, Outputs]
#
# T1[2]
......
@@ -6,10 +6,10 @@ import pytest

 import quake.job
 from quake.client.base.task import (
     Task,
+    make_input,
     new_mpirun_task,
-    upload_data,
     new_py_task,
-    make_input,
+    upload_data,
 )
......
-import quake.client as quake
 import pytest
+
+import quake.client as quake


 @quake.mpi_task(n_processes=1)
 def my_const():
@@ -107,6 +108,41 @@ def test_wrapper_args(client):
     assert quake.gather(g) == [24, 26, 28, 30]


+def test_three_tasks(client):
+    quake.set_global_client(client)
+
+    @quake.mpi_task(n_processes=1, n_outputs=4)
+    def my_preprocessing():
+        # Let us produce 4 pieces of data in a single-process MPI task
+        return ["something1", "something2", "something3", "something4"]
+
+    @quake.mpi_task(n_processes=4)
+    @quake.arg("my_data", layout="scatter")
+    def my_computation(my_config, my_data):
+        # This is called in 4 MPI processes
+        from mpi4py import MPI
+
+        comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
+        return "Computation at rank {}: configuration={}, data={}".format(
+            rank, my_config, my_data
+        )
+
+    data = my_preprocessing()
+    result = my_computation("my_configuration", data)
+    assert (
+        "\n".join(quake.gather(result))
+        == """Computation at rank 0: configuration=my_configuration, data=something1
+Computation at rank 1: configuration=my_configuration, data=something2
+Computation at rank 2: configuration=my_configuration, data=something3
+Computation at rank 3: configuration=my_configuration, data=something4"""
+    )
+    # for i, r in enumerate(quake.gather(result)):
+    #     print("Output {}: {}".format(i, r))


 def test_wrapper_error(client):
     quake.set_global_client(client)
......
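As the expected output in test_three_tasks shows, layout="scatter" distributes the four outputs of my_preprocessing one per rank of the four-process my_computation task (rank i receives data=something{i+1}), while the plain my_config argument is passed to every rank unchanged.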