Commit 5214abbc authored by Stanislav Bohm
Browse files

Layout support

parent b1f92730
......@@ -18,8 +18,8 @@ class Client:
self.connection = None
self.unsubmitted_tasks = []
self.loop = asyncio.get_event_loop()
self._connect(hostname, port)
self.id_counter = 0
self._connect(hostname, port)
def _connect(self, hostname, port):
async def connect():
......
import enum
from ..common.taskinput import TaskInput
from ..common.layout import Layout
class TaskState(enum.Enum):
......@@ -32,4 +33,10 @@ class Task:
}
def output(self, output_id, layout="all_to_all"):
if layout == "all_to_all":
layout = Layout(self.n_workers, 0, 0, 0, self.n_workers)
elif layout == "cycle":
layout = Layout(self.n_workers, 1, 0, 0, 1)
else:
assert isinstance(layout, Layout)
return TaskInput(self, output_id, layout)
class Layout:
    """Describes which part indices of a task output a given rank consumes.

    For rank r the consumed indices are the contiguous run starting at
    ``offset_r * r + offset_c`` of length ``block_size_r * r + block_size_c``,
    taken modulo ``size``.
    """

    __slots__ = ("size", "offset_r", "offset_c", "block_size_r", "block_size_c")

    def __init__(self, size, offset_r=0, offset_c=0, block_size_r=0, block_size_c=0):
        self.size = size
        self.offset_r = offset_r
        self.offset_c = offset_c
        self.block_size_r = block_size_r
        self.block_size_c = block_size_c

    def iterate(self, rank):
        """Yield the part indices consumed by `rank`, wrapping modulo size."""
        start = self.offset_r * rank + self.offset_c
        count = self.block_size_r * rank + self.block_size_c
        size = self.size
        for index in range(start, start + count):
            yield index % size

    def serialize(self):
        """Return a plain-list form of this layout for network transfer."""
        return [self.size, self.offset_r, self.offset_c,
                self.block_size_r, self.block_size_c]

    @staticmethod
    def deserialize(data):
        """Inverse of serialize(): rebuild a Layout from its list form."""
        return Layout(*data)
from .layout import Layout
class TaskInput:
__slots__ = ("task", "output_id", "layout")
def __init__(self, task, output_id: int, layout):
def __init__(self, task, output_id: int, layout: Layout):
assert 0 <= output_id < task.n_outputs
self.task = task
self.output_id = output_id
......@@ -11,9 +15,15 @@ class TaskInput:
return {
"task": self.task.task_id,
"output_id": self.output_id,
"layout": self.layout,
"layout": self.layout.serialize()
}
@staticmethod
def from_dict(data, tasks):
    """Reconstruct a TaskInput from its to_dict() form.

    `tasks` maps task ids to Task instances; the serialized layout list is
    revived via Layout.deserialize, mirroring to_dict()'s serialize().
    Fix: diff artifact resolved — the obsolete variant that passed the raw
    serialized layout through is dropped.
    """
    return TaskInput(
        tasks[data["task"]],
        data["output_id"],
        Layout.deserialize(data["layout"]))
def __repr__(self):
    # Debug summary: producing task id, output index and layout.
    return f"<Input task={self.task.task_id} o={self.output_id} l={self.layout}>"
def make_data_name(task_id, output_id, part_id):
return "data_{}_{}_{}".format(task_id, output_id, part_id)
def ffff():
pass
return "data_{}_{}_{}".format(task_id, output_id, part_id)
\ No newline at end of file
......@@ -22,7 +22,6 @@ class Service:
self.stats_obj_fetched = 0
self.stats_obj_data_provided = 0
# self.stats_obj_file_provided = 0
async def _serve(self, connection, hostname, port):
await connection.serve()
......@@ -86,6 +85,16 @@ class Service:
# raise Exception("Object removed")
# return data
@expose()
async def get_sizes(self, names):
    """Return the size of each named object, aligned with `names`.

    Unknown names yield None instead of an error so the caller can query
    many names in one round trip.
    """
    sizes = []
    for name in names:
        pending = self.objects.get(name)
        if pending is None:
            sizes.append(None)
        else:
            obj = await pending
            sizes.append(obj.size)
    return sizes
"""
@expose()
async def map_to_fs(self, name, hostname=None, port=None):
......@@ -103,10 +112,9 @@ class Service:
async def remove(self, name):
    """Forget object `name` from the local store.

    Returns True when the object was present and removed, False when the
    name is unknown (missing objects are treated as a no-op, not an error).
    Fix: diff artifact resolved — the obsolete `raise Exception("Object not
    found")` branch is dropped in favor of the `return False` contract.
    """
    obj_f = self.objects.get(name)
    if obj_f is None:
        return False
    del self.objects[name]
    # NOTE(review): explicit data release is currently disabled — confirm
    # whether dropping the future is sufficient cleanup.
    # obj = await obj_f
    # await obj.remove()
    return True
@expose()
async def get_stats(self):
......
......@@ -7,6 +7,7 @@ import random
import abrpc
from quake.common.layout import Layout
from quake.common.utils import make_data_name
logger = logging.getLogger(__name__)
......@@ -70,7 +71,8 @@ class Job:
await self.ds_connection.call("upload", name, data)
async def start(self):
logger.info("Starting task id=%s on rank=%s", self.task_id, self.rank)
rank = self.rank
logger.info("Starting task id=%s on rank=%s", self.task_id, rank)
await self.connect_to_ds()
config = await self.download_config()
......@@ -81,12 +83,14 @@ class Job:
fs = []
for inp_dict in inputs:
# TODO: Other layouts
assert inp_dict["layout"] == "all_to_all"
parts = range(inp_dict["n_parts"])
fs.append(self.download_input(inp_dict["task_id"], inp_dict["output_id"], parts))
layout = Layout.deserialize(inp_dict["layout"])
fs.append(self.download_input(
inp_dict["task_id"],
inp_dict["output_id"],
layout.iterate(rank)))
input_data = await asyncio.gather(*fs)
jctx = JobContext(self.rank, input_data)
jctx = JobContext(rank, input_data)
output = config.fn(jctx, input_data)
assert len(output) == config.n_outputs
......
......@@ -9,12 +9,10 @@ def compute_b_levels(tasks):
if c == 0:
stack.append(task)
stack = []
while stack:
task = stack.pop()
task.b_level = 1 + max((t.b_level for t in task.consumers), default=0)
for inp in task.inputs:
t = inp.task
for t in task.deps:
to_compute[t] -= 1
v = to_compute[t]
if v <= 0:
......
......@@ -57,9 +57,15 @@ async def _wait_for_task(task):
async def _remove_task(task):
    """Remove every placed part of a released task from its workers.

    Fix: diff artifact resolved — the obsolete line that read
    `task.placement` AFTER `set_released()` (which clears it) is dropped;
    the placement is snapshotted first.
    """
    fs = []
    # The ordering of the two following statements is important:
    # set_released() clears task.placement, so snapshot it first.
    placement = task.placement
    task.set_released()
    for output_id in range(task.n_outputs):
        for part in range(task.n_workers):
            fs.append(_remove_from_workers(
                placement[output_id][part],
                task.make_data_name(output_id, part)))
    await asyncio.gather(*fs)
    logger.debug("All parts of task %s was removed (%s calls)", task, len(fs))
......@@ -74,6 +80,16 @@ async def _remove_from_workers(workers, name):
await asyncio.wait(fs)
async def _download_sizes(task, workers):
fs = [w.ds_connection.call("get_sizes", [task.make_data_name(output_id, part_id) for output_id in range(task.n_outputs)])
for part_id, w in enumerate(workers)]
sizes = await asyncio.gather(*fs)
return [
[sizes[part_id][output_id] for part_id in range(task.n_workers)]
for output_id in range(task.n_outputs)
]
class Server:
def __init__(self, worker_hostnames, local_ds_port):
......@@ -81,8 +97,7 @@ class Server:
workers = []
for i, hostname in enumerate(worker_hostnames):
worker = Worker(hostname)
worker.worker_id = i
worker = Worker(i, hostname)
logger.info("Registering worker worker_id=%s host=%s", i, worker.hostname)
workers.append(worker)
......@@ -177,8 +192,8 @@ class Server:
asyncio.ensure_future(_remove_task(task))
self.schedule()
def _task_finished(self, task, workers, sizes):
    """Register a finished task (with its part sizes) and reschedule.

    Tasks whose data became removable after this completion are removed
    asynchronously. Fix: diff artifact resolved — the obsolete 2-argument
    variant is dropped; the loop variable no longer shadows `task`.
    """
    for released in self.state.on_task_finished(task, workers, sizes):
        asyncio.ensure_future(_remove_task(released))
    self.schedule()
......@@ -188,7 +203,7 @@ class Server:
fs = [workers[i].ds_connection.call("upload", task.make_data_name(0, i), data)
for i, data in enumerate(parts)]
await asyncio.wait(fs)
self._task_finished(task, workers)
self._task_finished(task, workers, [[len(data) for data in parts]])
except Exception as e:
logger.error(e)
self._task_failed(task, workers, "Upload failed: " + str(e))
......@@ -204,7 +219,7 @@ class Server:
{"task_id": inp.task.task_id,
"output_id": inp.output_id,
"n_parts": inp.task.n_workers,
"layout": inp.layout}
"layout": inp.layout.serialize()}
for inp in task.inputs
]
......@@ -249,7 +264,9 @@ class Server:
stdout = stdout_file.read().decode()
logger.info("Task id={} finished.\nStdout:\n{}\nStderr:\n{}\n".format(
task.task_id, stdout, stderr))
self._task_finished(task, workers)
sizes = await _download_sizes(task, workers)
logger.debug("Sizes of task=%s downloaded sizes=%s", task.task_id, sizes)
self._task_finished(task, workers, sizes)
finally:
if data_key:
await _remove_from_workers(workers, data_key)
......
......@@ -15,13 +15,59 @@ def _check_removal(task, tasks_to_remove):
tasks_to_remove.append(task)
def _task_b_level(task):
return task.b_level
def _release_task(task, tasks_to_remove):
    """Detach `task` from its producers, collecting tasks that became removable."""
    for inp in task.inputs:
        producer = inp.task
        if task in producer.consumers:
            # The membership guard tolerates duplicate inputs from the
            # same producer (each producer is unlinked only once).
            producer.consumers.remove(task)
            _check_removal(producer, tasks_to_remove)
    _check_removal(task, tasks_to_remove)
def _transfer_costs(task, part_id, worker):
cost = 0
for inp in task.inputs:
t = inp.task
placement = t.placement[inp.output_id]
sizes = t.sizes[inp.output_id]
for i in inp.layout.iterate(part_id):
if worker not in placement[i]:
cost += sizes[i]
return cost
def _update_placement(task, workers):
for inp in task.inputs:
t = inp.task
placement = t.placement[inp.output_id]
for rank in range(t.n_workers):
worker = workers[rank]
for i in inp.layout.iterate(rank):
placement[i].add(worker)
def _choose_workers(task, workers):
    """Greedily assign one distinct worker per part, minimizing transfer cost.

    Does not mutate `workers`; returns the chosen workers ordered by part id.
    """
    available = workers.copy()
    chosen = []
    for part_id in range(task.n_workers):
        def cost(candidate, _part=part_id):
            return _transfer_costs(task, _part, candidate)
        best = min(available, key=cost)
        available.remove(best)
        chosen.append(best)
    return chosen
class State:
def __init__(self, workers):
    """Scheduler state over a fixed pool of workers.

    Fix: diff artifact resolved — free_workers was initialized both as a
    list (old) and a set (new); the set version is kept, matching the
    `.update()` / `.remove()` usage elsewhere in this class.
    """
    self.tasks = {}          # task_id -> Task
    self.ready_tasks = []    # tasks whose dependencies have all finished
    self.all_workers = workers
    self.free_workers = set(workers)
    self.need_sort = False   # ready_tasks needs re-sorting by b-level
def add_tasks(self, serialized_tasks):
new_ready_tasks = False
......@@ -45,7 +91,7 @@ class State:
task = tdict["_task"]
unfinished_deps = 0
inputs = [TaskInput.from_dict(data, task_map) for data in tdict["inputs"]]
deps = frozenset(inp.task for inp in inputs)
deps = tuple(frozenset(inp.task for inp in inputs))
for t in deps:
assert t.state != TaskState.RELEASED
assert t.keep or t in new_tasks, "Dependency on not-keep task"
......@@ -60,50 +106,66 @@ class State:
logger.debug("Task %s is ready", task)
self.ready_tasks.append(task)
self.need_sort |= new_ready_tasks
compute_b_levels(task_map)
return new_ready_tasks
def _fake_placement(self, task, placement, sizes):
    """Mark `task` finished without running it, using the supplied
    placement and sizes, and wake any consumers that became ready."""
    task._fake_finish(placement, sizes)
    if task in self.ready_tasks:
        self.ready_tasks.remove(task)
    for consumer in task.consumers:
        consumer.unfinished_deps -= 1
        if consumer.unfinished_deps > 0:
            continue
        assert consumer.unfinished_deps == 0
        logger.debug("Task %s is ready", consumer)
        self.ready_tasks.append(consumer)
    self.need_sort = True
def schedule(self):
    """Yield (task, workers) pairs for ready tasks that can start now.

    ready_tasks is kept sorted ascending by b-level, so the list is walked
    backwards (highest b-level first); workers are picked greedily by
    transfer cost. Fix: diff artifact resolved — the obsolete list-slicing
    scheduler and the obsolete method-level copy of `_release_task` (now a
    module-level function) are removed.
    """
    if self.need_sort:
        self.ready_tasks.sort(key=_task_b_level)
    logger.debug("Scheduling ... top_3_tasks: %s", self.ready_tasks[:3])
    free_workers = self.free_workers
    for idx in range(len(self.ready_tasks) - 1, -1, -1):
        task = self.ready_tasks[idx]
        if task.n_workers <= len(free_workers):
            workers = _choose_workers(task, free_workers)
            for worker in workers:
                free_workers.remove(worker)
            del self.ready_tasks[idx]
            yield task, workers
            if not free_workers:
                break
    logger.debug("End of scheduling")
def on_task_failed(self, task, workers, message):
    """Mark `task` and its transitive consumers as failed; free its workers.

    Returns the list of tasks whose data can now be removed.
    Fix: diff artifact resolved — uses the module-level `_release_task`
    and `set.update()` on free_workers (free_workers is a set).
    """
    # TODO: Remove inputs that was downloaded for execution but they are not in placement
    logger.error("Task %s FAILED: %s", task, message)
    task.set_error(message)
    tasks_to_remove = []
    _release_task(task, tasks_to_remove)
    for consumer in task.recursive_consumers():
        if consumer.state == TaskState.UNFINISHED:
            consumer.set_error(message)
    self.free_workers.update(workers)
    return tasks_to_remove
def on_task_finished(self, task, workers, sizes):
    """Mark `task` finished, wake consumers that became ready, free workers.

    `sizes` is sizes[output_id][part_id] in bytes. Returns the list of
    tasks whose data can now be removed. Fix: diff artifact resolved —
    keeps the 3-argument signature, the module-level `_release_task`, and
    `set.update()` on free_workers.
    """
    logger.debug("Task %s finished", task)
    task.set_finished(workers, sizes)
    tasks_to_remove = []
    _release_task(task, tasks_to_remove)
    for consumer in task.consumers:
        consumer.unfinished_deps -= 1
        if consumer.unfinished_deps <= 0:
            assert consumer.unfinished_deps == 0
            logger.debug("Task %s is ready", consumer)
            self.ready_tasks.append(consumer)
    self.need_sort = True
    self.free_workers.update(workers)
    return tasks_to_remove
def unkeep(self, task_id):
......
......@@ -24,9 +24,9 @@ class Task:
self.unfinished_deps = None
self.consumers = set()
self.events = None
self.events = None
self.error = None
self.placement = None # placement[output_id][part_id] -> set of workers where is data placed
self.sizes = None # sizes[output_id][part_id] -> sizes of data parts
self.b_level = None
......@@ -41,7 +41,7 @@ class Task:
for t in task.consumers:
if t not in tasks:
stack.append(t)
tasks.append(t)
tasks.add(t)
return tasks
def add_event(self, event):
......@@ -65,16 +65,28 @@ class Task:
def is_ready(self):
    """True once every dependency of this task has finished.

    Compares against 0 explicitly: unfinished_deps may be None before the
    task is registered, and None must not count as ready.
    """
    remaining = self.unfinished_deps
    return remaining == 0
def set_finished(self, workers, sizes):
    """Mark this task finished on `workers` with part sizes `sizes`.

    sizes[output_id][part_id] gives the byte size of each produced part;
    the placement initially maps every part to the single worker that
    produced it. Fix: diff artifact resolved — the obsolete 2-argument
    definition is dropped.
    """
    assert self.state == TaskState.UNFINISHED
    assert len(workers) == self.n_workers
    assert len(sizes) == self.n_outputs
    assert not sizes or len(sizes[0]) == self.n_workers
    self.state = TaskState.FINISHED
    self.placement = [[{w} for w in workers] for _ in range(self.n_outputs)]
    self.sizes = sizes
    self._fire_events()
def _fake_finish(self, placement, sizes):
    """Force FINISHED state with externally supplied placement and sizes.

    Unlike set_finished(), performs no validation and fires no events.
    """
    self.state = TaskState.FINISHED
    self.placement = placement
    self.sizes = sizes
def set_released(self):
    """Transition FINISHED -> RELEASED and drop per-part bookkeeping."""
    assert self.state == TaskState.FINISHED
    self.state = TaskState.RELEASED
    # Free the (potentially large) bookkeeping structures.
    self.placement = None
    self.sizes = None
    self.consumers = None
def __repr__(self):
    # Short id / worker-count summary for logs.
    return f"<Task id={self.task_id} w={self.n_workers}>"
class Worker:
    """A remote worker node as tracked by the server.

    Fix: diff artifact resolved — the obsolete no-argument-id __init__
    (which assigned worker_id separately) is dropped in favor of the
    (worker_id, hostname) constructor.
    """

    def __init__(self, worker_id, hostname):
        self.worker_id = worker_id
        self.hostname = hostname
        self.ds_connection = None  # data-service connection, set after connecting
        self.tasks = set()         # tasks currently assigned to this worker

    def __repr__(self):
        return "<Worker id={}>".format(self.worker_id)
from quake.common.taskinput import Layout
def test_layouts():
    """Layout.iterate yields the expected part indices for each rank."""
    def per_rank(layout):
        return [list(layout.iterate(rank)) for rank in range(4)]

    all_to_all = Layout(4, 0, 0, 0, 4)
    assert per_rank(all_to_all) == [[0, 1, 2, 3]] * 4

    cycle = Layout(4, 1, 0, 0, 1)
    assert per_rank(cycle) == [[0], [1], [2], [3]]

    wrapping_cycle = Layout(2, 1, 0, 0, 1)
    assert per_rank(wrapping_cycle) == [[0], [1], [0], [1]]
\ No newline at end of file
import pytest
from quake.client.task import Task
# TX[CPUS, Outputs]
#
# T1[2]
# | T3[2]
# T2[3] |
# |\ |
# | \ / \
# | \ / \
# | T4[4,2] T5[1]
# | (0) (1) /
# | / \ /
# |/ \ /
# T6[2] T7[2,2]
# \ (0)(1)
# \ / /
# T8[4]
from quake.server import Worker
from quake.server.state import State
def make_workers(count):
    """Create `count` Worker stubs named w_0 .. w_{count-1}."""
    workers = []
    for worker_id in range(count):
        workers.append(Worker(worker_id, "w_{}".format(worker_id)))
    return workers
def make_plan1():
    """Serialize the eight-task test graph drawn in the diagram above."""
    t1 = Task(1, 1, 2, None, False, [])
    t2 = Task(2, 1, 3, None, False, [t1.output(0)])
    t3 = Task(3, 1, 2, None, False, [])
    t4 = Task(4, 2, 4, None, False, [t2.output(0), t3.output(0)])
    t5 = Task(5, 1, 1, None, False, [t3.output(0)])
    t6 = Task(6, 1, 2, None, False, [t4.output(0), t2.output(0), t2.output(0)])
    t7 = Task(7, 2, 2, None, False, [t4.output(1), t5.output(0)])
    t8 = Task(8, 1, 4, None, False, [t6.output(0), t7.output(1), t7.output(0)])
    return [t.to_dict() for t in (t1, t2, t3, t4, t5, t6, t7, t8)]
def test_plan():
    """End-to-end check of b-level computation and greedy scheduling."""
    state = State(make_workers(4))
    assert state.add_tasks(make_plan1())

    expected_b_levels = {8: 1, 7: 2, 6: 2, 5: 3, 4: 3, 3: 4, 2: 4, 1: 5}
    for task_id, b_level in expected_b_levels.items():
        assert state.tasks.get(task_id).b_level == b_level

    scheduled = sorted(state.schedule(), key=lambda pair: pair[0].task_id)
    run1, run3 = scheduled
    assert run1[0].task_id == 1
    assert len(run1[1]) == 2
    assert run3[0].task_id == 3
    assert len(run3[1]) == 2
    # t1 and t3 must run on disjoint workers
    assert len(set(run3[1] + run1[1])) == 4

    state.on_task_finished(scheduled[0][0], scheduled[0][1], [[100, 100]])
    state.on_task_finished(scheduled[1][0], scheduled[1][1], [[20, 20]])

    scheduled = sorted(state.schedule(), key=lambda pair: pair[0].task_id)
    run2, run5 = scheduled
    assert run2[0].task_id == 2
    assert len(run2[1]) == 3
    # t2 should reuse t1's workers (its input data lives there)
    assert len(set(run2[1] + run1[1])) == 3
    assert run5[0].task_id == 5
    assert len(run5[1]) == 1
    # t5 should land on one of t3's workers
    assert len(set(run3[1] + run5[1])) == 2
    assert len(set(run2[1] + run5[1])) == 4
def test_greedy_match():
# T4 inputs
# t1 t1 t1 | t2 | t3 t3
# ------------------------------
# 100 0 0 | 200 | 1 0
# 0 100 0 | 200 | 0 2
# 0 0 100 | 200 | 1 0
# placements
# | t1 t1 t1 | t2 | t3 t3
# ---------------------------------
# w0 | 0 0 0 | 0 | 0 0
# w1 | 0 0 0 | 0 | 1 0
# w2 | 0 0 0 | 200 | 1 2
# w3 | 100 0 0 | 200 | 0 0
t1 = Task(1, 1, 3, None, False, [])
t2 = Task(2, 1, 1, None, False, [])