...
 
Commits (6)
......@@ -12,6 +12,7 @@ message Task {
repeated int32 input_ids = 3;
optional int32 resource_request_index = 5 [default = -1];
optional bool result = 6;
optional string checkpoint_path = 7;
optional string label = 12;
optional bytes metadata = 13;
......@@ -62,6 +63,7 @@ message WorkerCommand {
REMOVE = 3;
DICTIONARY = 8;
UPDATE = 9;
LOAD_CHECKPOINT = 10;
}
required Type type = 1;
......@@ -74,6 +76,9 @@ message WorkerCommand {
repeated int32 task_inputs = 5;
optional int32 n_cpus = 6;
// TASK + LOAD_CHECKPOINT
optional string checkpoint_path = 7;
// SEND
optional string address = 10;
......@@ -88,17 +93,22 @@ message WorkerCommand {
message WorkerResponse {
enum Type {
FINISHED = 1;
TRANSFERED = 2;
FAILED = 3;
FINISHED_AND_CHECKPOINTING = 2;
TRANSFERED = 3;
FAILED = 4;
CHECKPOINT_WRITTEN = 5;
CHECKPOINT_WRITE_FAILED = 6;
CHECKPOINT_LOADED = 7;
CHECKPOINT_LOAD_FAILED = 8;
}
required Type type = 1;
required int32 id = 2;
// FINISHED
// FINISHED + FINISHED_AND_CHECKPOINTING + CHECKPOINT_LOADED
optional uint64 size = 3;
optional uint64 length = 4;
// FAILED
// FAILED + CHECKPOINT_WRITE_FAILED + CHECKPOINT_LOAD_FAILED
optional string error_msg = 100;
}
......@@ -154,6 +164,7 @@ message ClientRequest {
// PLAN
optional Plan plan = 2;
optional bool load_checkpoints = 4;
// FETCH + RELEASE
optional int32 id = 3;
......
......@@ -12,27 +12,9 @@ import struct
import cloudpickle
import os
LOOM_PROTOCOL_VERSION = 2
class LoomException(Exception):
"""Base class for Loom exceptions"""
pass
class TaskFailed(LoomException):
"""Exception when scheduler informs about failure of a task"""
from .errors import LoomError, LoomException, TaskFailed # noqa
def __init__(self, id, worker, error_msg):
self.id = id
self.worker = worker
self.error_msg = error_msg
message = "Task id={} failed: {}".format(id, error_msg)
LoomException.__init__(self, message)
class LoomError(LoomException):
"""Generic error in Loom system"""
LOOM_PROTOCOL_VERSION = 2
class Client(object):
......@@ -259,7 +241,7 @@ class Client(object):
print(t)
assert 0
def submit_one(self, task):
def submit_one(self, task, load=False):
"""Submits a task to the server and returns a future
Args:
......@@ -274,9 +256,9 @@ class Client(object):
>>> result = client.submit(task3)
>>> print(result.gather())
"""
return self.submit((task,))[0]
return self.submit((task,), load=load)[0]
def submit(self, tasks):
def submit(self, tasks, load=False):
"""Submits tasks to the server and returns list of futures
Args:
......@@ -297,6 +279,7 @@ class Client(object):
futures = self.futures
results = []
for task in tasks:
task.validate()
if not isinstance(task, Task):
raise Exception("{} is not a task".format(task))
plan.add(task)
......@@ -311,7 +294,7 @@ class Client(object):
msg = ClientRequest()
msg.type = ClientRequest.PLAN
msg.load_checkpoints = load
include_metadata = self.trace_path is not None
msg.plan.id_base = id_base
plan.set_message(
......
class LoomException(Exception):
    """Root of the Loom exception hierarchy."""
class TaskFailed(LoomException):
    """Raised when the scheduler reports that a task has failed.

    The constructor arguments are kept on the instance as ``id``,
    ``worker`` and ``error_msg``; the exception message is built from
    ``id`` and ``error_msg``.
    """

    def __init__(self, id, worker, error_msg):
        self.id, self.worker, self.error_msg = id, worker, error_msg
        super(TaskFailed, self).__init__(
            "Task id={} failed: {}".format(id, error_msg))
class LoomError(LoomException):
    """General-purpose error raised by the Loom system."""
......@@ -69,6 +69,8 @@ class Plan(object):
msg_t.task_type = symbols[task.task_type]
msg_t.input_ids.extend(self.get_id(t) for t in task.inputs)
msg_t.result = task in results
if task.checkpoint_path:
msg_t.checkpoint_path = task.checkpoint_path
if task.resource_request:
msg_t.resource_request_index = \
requests.index(task.resource_request)
......
import os.path
from .errors import LoomError
class Task(object):
......@@ -7,6 +10,12 @@ class Task(object):
resource_request = None
label = None
metadata = None
checkpoint_path = None
def validate(self):
    """Check the task before submission.

    Raises:
        LoomError: if ``checkpoint_path`` is set but is not an absolute path.
    """
    path = self.checkpoint_path
    if path is None or os.path.isabs(path):
        return
    raise LoomError("Checkpoint has to be absolute path")
def __repr__(self):
if self.label:
......
......@@ -61,3 +61,9 @@ int loom::base::make_path(const char *path, mode_t mode)
}
return 0;
}
// Reports whether `path` can be resolved by stat(); returns true exactly
// when the stat() call succeeds.
bool loom::base::file_exists(const char *path)
{
    struct stat st;
    const int rc = stat(path, &st);
    return rc == 0;
}
......@@ -12,6 +12,7 @@ namespace base {
int make_path(const char *path, mode_t mode);
size_t file_size(const char *path);
bool file_exists(const char *path);
}
}
......
......@@ -48,6 +48,8 @@ add_library(libloomw
resalloc.cpp
task.cpp
task.h
checkpointwriter.h
checkpointwriter.cpp
wtrace.cpp
wtrace.h
taskdesc.h
......
#include "checkpointwriter.h"
#include "worker.h"
#include "libloom/log.h"
#include <stdio.h>
// Captures everything needed to write one checkpoint: the owning worker,
// the data id, the data itself and the destination path.
loom::CheckPointWriter::CheckPointWriter(loom::Worker &worker,
                                         base::Id id,
                                         const loom::DataPtr &data,
                                         const std::string &path)
    : worker(worker),
      id(id),
      data(data),
      path(path)
{
    // The static libuv callbacks only receive the uv_work_t, so stash a
    // back-pointer to this object in its user-data slot.
    work.data = this;
}
// Queues the write on the worker's libuv loop: _work_cb does the file I/O,
// _after_work_cb reports the outcome.
void loom::CheckPointWriter::start()
{
    auto loop = worker.get_loop();
    UV_CHECK(uv_queue_work(loop, &work, _work_cb, _after_work_cb));
}
/**
 * Work callback registered through uv_queue_work (libuv runs it on its
 * thread pool): writes the data to "<path>.loom.tmp" and then renames it to
 * <path>, so a partially written file never appears under the final name.
 *
 * Nothing is returned; on any failure a human-readable description is stored
 * in writer->error, which _after_work_cb inspects. An empty error string
 * means success.
 *
 * @param req  libuv work request whose `data` field points to the
 *             CheckPointWriter instance (set in the constructor).
 */
void loom::CheckPointWriter::_work_cb(uv_work_t *req)
{
    CheckPointWriter *writer = static_cast<CheckPointWriter*>(req->data);
    const std::string &path = writer->path;

    // Validate the payload before touching the file system so that data
    // without a raw representation does not leave an empty temp file behind.
    const char *ptr = writer->data->get_raw_data();
    if (!ptr) {
        writer->error = "Data '" + writer->data->get_info() + "' cannot be checkpointed";
        return;
    }

    const std::string tmp_path = path + ".loom.tmp";
    std::ofstream fout(tmp_path.c_str(), std::ios::out | std::ios::binary | std::ios::trunc);
    if (!fout.is_open()) {
        // Capture errno before the string concatenations below can disturb it.
        const int err = errno;
        writer->error = "Writing checkpoint '" + path + "' failed. Cannot create " + tmp_path + ":" + strerror(err);
        return;
    }

    fout.write(ptr, writer->data->get_size());
    fout.close();
    if (fout.fail()) {
        // A short write (e.g. full disk) must not be promoted to the final
        // checkpoint name; drop the temp file and report the failure.
        const int err = errno;
        writer->error = "Writing checkpoint '" + path + "' failed. Cannot write " + tmp_path + ":" + strerror(err);
        unlink(tmp_path.c_str());
        return;
    }

    if (rename(tmp_path.c_str(), path.c_str())) {
        const int err = errno;
        writer->error = "Writing checkpoint '" + path + "' failed. Cannot move " + tmp_path + ":" + strerror(err);
        unlink(tmp_path.c_str());
    }
}
// Completion callback paired with _work_cb: tells the worker whether the
// checkpoint was written (empty error string) or failed, then destroys the
// writer.
void loom::CheckPointWriter::_after_work_cb(uv_work_t *req, int status)
{
    UV_CHECK(status);
    CheckPointWriter *self = static_cast<CheckPointWriter*>(req->data);
    const bool failed = !self->error.empty();
    if (failed) {
        self->worker.checkpoint_write_failed(self->id, self->error);
    } else {
        self->worker.checkpoint_written(self->id);
    }
    // Nothing references the writer after this point; release it here.
    delete self;
}
#ifndef LIBLOOMW_CHECKPOINTWRITER_H
#define LIBLOOMW_CHECKPOINTWRITER_H
#include "libloom/types.h"
#include "data.h"
namespace loom {
/**
 * Asynchronous writer of a single checkpoint file.
 *
 * An instance bundles the libuv work request with the data to be stored and
 * the destination path. start() queues the request; _work_cb performs the
 * write and _after_work_cb reports the result to the worker.
 */
class CheckPointWriter {
public:
    CheckPointWriter(Worker &worker, base::Id id, const DataPtr &data, const std::string &path);

    // `work.data` points back at this object and libuv keeps the address of
    // `work` while the request is queued, so a copy would alias the original.
    CheckPointWriter(const CheckPointWriter &) = delete;
    CheckPointWriter &operator=(const CheckPointWriter &) = delete;

    // Queue the write on the worker's event loop.
    void start();

protected:
    uv_work_t work;     // libuv request; its `data` field holds `this`
    Worker &worker;     // worker notified about the outcome
    base::Id id;        // id of the data being checkpointed
    DataPtr data;       // data to write
    std::string path;   // final checkpoint path
    std::string error;  // empty on success, otherwise a failure description

private:
    static void _work_cb(uv_work_t *req);
    static void _after_work_cb(uv_work_t *req, int status);
};
}
#endif
......@@ -54,7 +54,7 @@ void InterConnection::finish_receive()
{
logger->debug("Interconnect: Data id={} received", unpacking_data_id);
worker.data_transferred(unpacking_data_id);
worker.publish_data(unpacking_data_id, unpacker->finish());
worker.publish_data(unpacking_data_id, unpacker->finish(), "");
auto &trace = worker.get_trace();
if (trace) {
......
......@@ -16,8 +16,8 @@ class Worker;
class Task {
public:
Task(base::Id id, int task_type, const std::string &config, int n_cpus)
: id(id), task_type(task_type), config(config), n_cpus(n_cpus), n_unresolved(0) {}
Task(base::Id id, int task_type, const std::string &config, int n_cpus, const std::string &checkpoint_path)
: id(id), task_type(task_type), config(config), n_cpus(n_cpus), n_unresolved(0), checkpoint_path(checkpoint_path) {}
Task(base::Id id, int task_type, std::string &&config, int n_cpus)
: id(id), task_type(task_type), config(std::move(config)), n_cpus(n_cpus), n_unresolved(0) {}
......@@ -47,6 +47,10 @@ public:
return n_cpus;
}
const std::string& get_checkpoint_path() const {
return checkpoint_path;
}
const std::vector<base::Id>& get_inputs() const {
return inputs;
}
......@@ -61,6 +65,7 @@ protected:
int n_cpus;
size_t n_unresolved;
std::unordered_set<base::Id> unresolved_set;
std::string checkpoint_path;
};
}
......
......@@ -54,9 +54,9 @@ void TaskInstance::fail_libuv(const std::string &error_msg, int error_code)
void TaskInstance::finish(const DataPtr &output)
{
assert(output);
worker.publish_data(get_id(), output);
worker.publish_data(get_id(), output, task->get_checkpoint_path());
assert(output);
worker.task_finished(*this, output);
worker.task_finished(*this, output, !task->get_checkpoint_path().empty());
}
void TaskInstance::redirect(std::unique_ptr<TaskDescription> tdesc)
......
......@@ -21,8 +21,10 @@
#include "libloom/sendbuffer.h"
#include "libloom/pbutils.h"
#include "libloom/fsutils.h"
#include "data/externfile.h"
#include "loom_define.h"
#include "checkpointwriter.h"
#include <stdlib.h>
#include <sstream>
......@@ -283,10 +285,59 @@ void Worker::start_task(std::unique_ptr<Task> task, ResourceAllocation &&ra)
t->start(input_data);
}
void Worker:: publish_data(Id id, const DataPtr &data)
// Logs a successful checkpoint write and, when the server connection is up,
// sends a CHECKPOINT_WRITTEN response carrying the data id.
void Worker::checkpoint_written(Id id) {
    using namespace loom::pb::comm;
    logger->debug("Checkpoint written id={}", id);
    if (!server_conn.is_connected()) {
        return;
    }
    WorkerResponse response;
    response.set_type(WorkerResponse_Type_CHECKPOINT_WRITTEN);
    response.set_id(id);
    send_message(server_conn, response);
}
// Logs a failed checkpoint write and, when the server connection is up,
// sends a CHECKPOINT_WRITE_FAILED response with the id and error text.
void Worker::checkpoint_write_failed(Id id, const std::string &error_msg) {
    using namespace loom::pb::comm;
    logger->debug("Cannot write checkpoint id={}, error={}", id, error_msg);
    if (!server_conn.is_connected()) {
        return;
    }
    WorkerResponse response;
    response.set_type(WorkerResponse_Type_CHECKPOINT_WRITE_FAILED);
    response.set_id(id);
    response.set_error_msg(error_msg);
    send_message(server_conn, response);
}
// Logs a successful checkpoint load and, when the server connection is up,
// sends a CHECKPOINT_LOADED response with the id plus the size and length
// of the loaded data.
void Worker::checkpoint_loaded(Id id, const DataPtr &data) {
    using namespace loom::pb::comm;
    logger->debug("Checkpoint loaded id={}", id);
    if (!server_conn.is_connected()) {
        return;
    }
    WorkerResponse response;
    response.set_type(WorkerResponse_Type_CHECKPOINT_LOADED);
    response.set_id(id);
    response.set_size(data->get_size());
    response.set_length(data->get_length());
    send_message(server_conn, response);
}
// Logs a failed checkpoint load and, when the server connection is up,
// sends a CHECKPOINT_LOAD_FAILED response with the id and error text.
void Worker::checkpoint_load_failed(Id id, const std::string &error_msg) {
    using namespace loom::pb::comm;
    logger->debug("Cannot load checkpoint id={}, error={}", id, error_msg);
    if (!server_conn.is_connected()) {
        return;
    }
    WorkerResponse response;
    response.set_type(WorkerResponse_Type_CHECKPOINT_LOAD_FAILED);
    response.set_id(id);
    response.set_error_msg(error_msg);
    send_message(server_conn, response);
}
// Registers `data` under `id` in public_data, starts a checkpoint write when
// a non-empty checkpoint path is given, and then re-checks tasks waiting on
// this id.
void Worker::publish_data(Id id, const DataPtr &data, const std::string &checkpoint_path)
{
    logger->debug("Publishing data id={} size={} info={}", id, data->get_size(), data->get_info());
    public_data[id] = data;

    const bool checkpoint_requested = !checkpoint_path.empty();
    if (checkpoint_requested) {
        write_checkpoint(id, data, checkpoint_path);
    }

    check_waiting_tasks(id);
}
......@@ -480,6 +531,20 @@ void Worker::remove_task(TaskInstance &task, bool free_resources)
assert(0);
}
// Loads a checkpoint for `id` from `path`. If the file exists it is wrapped
// in an ExternFile, reported via checkpoint_loaded() and published without a
// checkpoint path (so it is not written again); otherwise the failure is
// logged and reported via checkpoint_load_failed().
void Worker::load_checkpoint(Id id, const std::string &path) {
    if (file_exists(path.c_str())) {
        DataPtr loaded = std::make_shared<ExternFile>(path);
        checkpoint_loaded(id, loaded);
        publish_data(id, loaded, "");
        return;
    }

    const std::string error = "File '" + path + "' does not exists";
    logger->error("Cannot load checkpoint {}: {}", id, error);
    checkpoint_load_failed(id, error);
}
void Worker::task_failed(TaskInstance &task, const std::string &error_msg)
{
logger->error("Task id={} failed: {}", task.get_id(), error_msg);
......@@ -519,12 +584,18 @@ void Worker::task_redirect(TaskInstance &task,
t->start(new_task_desc->inputs);
}
void Worker::task_finished(TaskInstance &task, const DataPtr &data)
// Starts an asynchronous checkpoint write of `data` to `checkpoint_path`.
void Worker::write_checkpoint(Id id, const DataPtr &data, const std::string &checkpoint_path)
{
    // Deliberately heap-allocated and not freed here: the writer deletes
    // itself in CheckPointWriter::_after_work_cb once the job completes.
    auto *job = new CheckPointWriter(*this, id, data, checkpoint_path);
    job->start();
}
void Worker::task_finished(TaskInstance &task, const DataPtr &data, bool checkpointing)
{
using namespace loom::pb::comm;
if (server_conn.is_connected()) {
WorkerResponse msg;
msg.set_type(WorkerResponse_Type_FINISHED);
msg.set_type(checkpointing ? WorkerResponse_Type_FINISHED_AND_CHECKPOINTING : WorkerResponse_Type_FINISHED);
msg.set_id(task.get_id());
msg.set_size(data->get_size());
msg.set_length(data->get_length());
......@@ -571,7 +642,8 @@ void Worker::on_message(const char *data, size_t size)
auto task = std::make_unique<Task>(msg.id(),
msg.task_type(),
msg.task_config(),
msg.n_cpus());
msg.n_cpus(),
msg.checkpoint_path());
for (int i = 0; i < msg.task_inputs_size(); i++) {
Id task_id = msg.task_inputs(i);
task->add_input(task_id);
......@@ -603,6 +675,12 @@ void Worker::on_message(const char *data, size_t size)
}
break;
}
case comm::WorkerCommand_Type_LOAD_CHECKPOINT: {
logger->debug("Loading checkpoint id={} path={}", msg.id(), msg.checkpoint_path());
assert(msg.has_checkpoint_path());
load_checkpoint(msg.id(), msg.checkpoint_path());
break;
}
case comm::WorkerCommand_Type_DICTIONARY: {
auto count = msg.symbols_size();
logger->debug("New dictionary ({} symbols)", count);
......
......@@ -52,12 +52,13 @@ public:
return true;
}
void task_finished(TaskInstance &task_instance, const DataPtr &data);
void task_finished(TaskInstance &task_instance, const DataPtr &data, bool checkpointing);
void task_failed(TaskInstance &task_instance, const std::string &error_msg);
void data_transferred(base::Id task_id);
void task_redirect(TaskInstance &task, std::unique_ptr<TaskDescription> new_task_desc);
void publish_data(base::Id id, const DataPtr &data);
void publish_data(base::Id id, const DataPtr &data, const std::string &checkpoint_path);
void write_checkpoint(base::Id id, const DataPtr &data, const std::string &checkpoint_path);
void remove_data(base::Id id);
bool has_data(base::Id id) const
......@@ -126,8 +127,16 @@ public:
return trace;
}
void on_dictionary_updated();
void load_checkpoint(base::Id id, const std::string &path);
void checkpoint_written(base::Id id);
void checkpoint_write_failed(base::Id id, const std::string &error_msg);
void checkpoint_loaded(base::Id id, const DataPtr &Data);
void checkpoint_load_failed(base::Id id, const std::string &error_msg);
private:
void register_worker();
void create_trace(const std::string &trace_path, loom::base::Id worker_id);
......
......@@ -56,8 +56,8 @@ void ClientConnection::on_message(const char *buffer, size_t size)
case ClientRequest_Type_PLAN: {
logger->debug("Plan received");
const Plan &plan = request.plan();
loom::base::Id id_base = task_manager.add_plan(plan);
logger->info("Plan submitted tasks={}", plan.tasks_size());
loom::base::Id id_base = task_manager.add_plan(plan, request.load_checkpoints());
logger->info("Plan submitted tasks={}, load_checkpoints={}", plan.tasks_size(), request.load_checkpoints());
if (server.get_trace()) {
server.create_file_in_trace_dir(std::to_string(id_base) + ".plan", buffer, size);
......
......@@ -5,6 +5,7 @@
#include "pb/comm.pb.h"
#include "libloom/log.h"
#include "libloom/fsutils.h"
constexpr static double TRANSFER_COST_COEF = 1.0 / (1024 * 1024); // 1MB = 1cost
......@@ -22,16 +23,34 @@ ComputationState::ComputationState(Server &server) : server(server)
void ComputationState::add_node(std::unique_ptr<TaskNode> &&node) {
auto id = node->get_id();
for (TaskNode* input_node : node->get_inputs()) {
input_node->add_next(node.get());
auto result = nodes.insert(std::make_pair(id, std::move(node)));
assert(result.second); // Check that ID is fresh
}
void ComputationState::plan_node(TaskNode &node, bool load_checkpoints, std::vector<TaskNode *> &to_load) {
if (node.is_planned()) {
return;
}
node.set_planned();
if (node->is_ready()) {
pending_nodes.insert(node.get());
if (load_checkpoints && !node.get_task_def().checkpoint_path.empty() && loom::base::file_exists(node.get_task_def().checkpoint_path.c_str())) {
node.set_checkpoint();
to_load.push_back(&node);
return;
}
auto result = nodes.insert(std::make_pair(id, std::move(node)));
assert(result.second); // Check that ID is fresh
int remaining_inputs = 0;
for (TaskNode *input_node : node.get_inputs()) {
plan_node(*input_node, load_checkpoints, to_load);
if (!input_node->is_computed()) {
remaining_inputs += 1;
input_node->add_next(&node);
}
}
node.set_remaining_inputs(remaining_inputs);
if (remaining_inputs == 0) {
pending_nodes.insert(&node);
}
}
void ComputationState::reserve_new_nodes(size_t size)
......@@ -241,7 +260,7 @@ void ComputationState::make_expansion(std::vector<std::string> &configs,
}
}*/
loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan)
loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints, std::vector<TaskNode*> &to_load)
{
auto task_size = plan.tasks_size();
assert(plan.has_id_base());
......@@ -267,8 +286,11 @@ loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan)
def.task_type = pt.task_type();
def.config = pt.config();
def.checkpoint_path = pt.checkpoint_path();
bool is_result = false;
if (pt.has_result() && pt.result()) {
def.flags.set(static_cast<size_t>(TaskFlags::RESULT));
is_result = true;
def.flags.set(static_cast<size_t>(TaskDefFlags::RESULT));
}
auto inputs_size = pt.input_ids_size();
......@@ -283,7 +305,12 @@ loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan)
n_cpus = resources[pt.resource_request_index()];
}
def.n_cpus = n_cpus;
add_node(std::make_unique<TaskNode>(id, std::move(def)));
auto new_node = std::make_unique<TaskNode>(id, std::move(def));
if (is_result) {
plan_node(*new_node.get(), load_checkpoints, to_load);
}
add_node(std::move(new_node));
}
return id_base;
}
......
......@@ -41,11 +41,9 @@ public:
void remove_node(TaskNode &node);
bool is_ready(const TaskNode &node);
int get_n_data_objects() const;
loom::base::Id add_plan(const loom::pb::comm::Plan &plan);
loom::base::Id add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints, std::vector<TaskNode *> &to_load);
void test_ready_nodes(std::vector<loom::base::Id> ids);
loom::base::Id pop_result_client_id(loom::base::Id id);
......@@ -62,8 +60,9 @@ public:
std::unique_ptr<TaskNode> pop_node(loom::base::Id id);
void clear_all();
void add_pending_node(TaskNode &node);
void add_pending_node(TaskNode &node);
void plan_node(TaskNode &node, bool load_checkpoints, std::vector<TaskNode*> &to_load);
void fail_task_on_worker(WorkerConnection &conn);
private:
std::unordered_map<loom::base::Id, std::unique_ptr<TaskNode>> nodes;
std::unordered_set<TaskNode*> pending_nodes;
......
......@@ -128,7 +128,7 @@ static void compute_table(const TaskNode *node,
if (node == input_node) {
continue;
}
if (input_node->has_state()) {
if (input_node->is_computed()) {
Score score = score_from_next_size(input_node->get_size());
for (const auto &pair : input_node->get_workers()) {
WorkerConnection *wc = pair.first;
......@@ -340,7 +340,7 @@ TaskDistribution schedule(const ComputationState &cstate)
continue;
}
for (TaskNode *input_node : next_node->get_inputs()) {
if (input_node == best_node || input_node->has_state()) {
if (input_node == best_node || input_node->is_computed()) {
continue;
}
auto it = context.units.find(input_node->get_id());
......
......@@ -48,6 +48,7 @@ void Server::add_worker_connection(std::unique_ptr<WorkerConnection> &&conn)
void Server::remove_worker_connection(WorkerConnection &conn)
{
task_manager.worker_fail(conn);
auto i = std::find_if(
connections.begin(),
connections.end(),
......@@ -81,9 +82,29 @@ void Server::remove_freshconnection(FreshConnection &conn)
fresh_connections.erase(i);
}
void Server::on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc)
void Server::on_checkpoint_write_finished(loom::base::Id id, WorkerConnection *wc)
{
task_manager.on_task_finished(id, size, length, wc);
task_manager.on_checkpoint_write_finished(id, wc);
}
// Hands a checkpoint-write failure reported by worker `wc` to the task manager.
void Server::on_checkpoint_write_failed(loom::base::Id id,
                                        WorkerConnection *wc,
                                        const std::string &error_msg)
{
    task_manager.on_checkpoint_write_failed(id, wc, error_msg);
}
// Hands a completed checkpoint load (with the data's size and length) to the
// task manager.
void Server::on_checkpoint_load_finished(loom::base::Id id,
                                         WorkerConnection *wc,
                                         size_t size,
                                         size_t length)
{
    task_manager.on_checkpoint_load_finished(id, wc, size, length);
}
// Hands a checkpoint-load failure reported by worker `wc` to the task manager.
void Server::on_checkpoint_load_failed(loom::base::Id id,
                                       WorkerConnection *wc,
                                       const std::string &error_msg)
{
    task_manager.on_checkpoint_load_failed(id, wc, error_msg);
}
// Hands a finished-task notification to the task manager; `checkpointing`
// is passed through unchanged.
void Server::on_task_finished(loom::base::Id id,
                              size_t size,
                              size_t length,
                              WorkerConnection *wc,
                              bool checkpointing)
{
    task_manager.on_task_finished(id, size, length, wc, checkpointing);
}
void Server::on_data_transferred(loom::base::Id id, WorkerConnection *wc)
......
......@@ -56,7 +56,7 @@ public:
void add_resend_task(loom::base::Id id);
void on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc);
void on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing);
void on_data_transferred(loom::base::Id id, WorkerConnection *wc);
loom::base::Dictionary& get_dictionary() {
......@@ -88,6 +88,10 @@ public:
return trace;
}
void on_checkpoint_write_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg);
void on_checkpoint_write_finished(loom::base::Id id, WorkerConnection *wc);
void on_checkpoint_load_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg);
void on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc, size_t size, size_t length);
private:
......
......@@ -9,6 +9,8 @@
#include <algorithm>
#include <assert.h>
#include <limits.h>
#include <stdlib.h>
#include <memory>
using namespace loom;
using namespace loom::base;
......@@ -18,9 +20,15 @@ TaskManager::TaskManager(Server &server)
{
}
loom::base::Id TaskManager::add_plan(const loom::pb::comm::Plan &plan)
loom::base::Id TaskManager::add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints)
{
loom::base::Id id_base = cstate.add_plan(plan);
std::vector<TaskNode*> to_load;
loom::base::Id id_base = cstate.add_plan(plan, load_checkpoints, to_load);
for (TaskNode *node : to_load) {
WorkerConnection *wc = random_worker();
node->set_as_loading(wc);
wc->load_checkpoint(node->get_id(), node->get_task_def().checkpoint_path);
}
distribute_work(schedule(cstate));
return id_base;
}
......@@ -76,14 +84,16 @@ void TaskManager::remove_node(TaskNode &node)
assert(status == TaskStatus::OWNER);
wc->remove_data(id);
});
node.set_not_needed();
//cstate.remove_node(node);
node.reset_owners();
cstate.remove_node(node);
}
void TaskManager::on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc)
void TaskManager::on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing)
{
if (unlikely(wc->is_blocked())) {
wc->residual_task_finished(id, true);
wc->residual_task_finished(id, true, checkpointing);
return;
}
TaskNode &node = cstate.get_node(id);
......@@ -94,35 +104,39 @@ void TaskManager::on_task_finished(loom::base::Id id, size_t size, size_t length
trace->trace_task_end(*node, wc);
}*/
if (node.is_result()) {
// Task with checkpoints is "finished" for worker only if
// checkpoint is written
if (!checkpointing && node.is_result()) {
logger->debug("Job id={} [RESULT] finished", id);
/*WorkerConnection *owner = node->get_random_owner();
assert(owner);
owner->send_data(id, server.get_dummy_worker().get_address());*/
ClientConnection *cc = server.get_client_connection();
if (cc) {
cc->send_info_about_finished_result(node);
cc->send_info_about_finished_result(node);
}
} else {
assert(!node.get_nexts().empty());
assert(checkpointing || !node.get_nexts().empty());
logger->debug("Job id={} finished (size={}, length={})", id, size, length);
}
if (checkpointing) {
wc->change_checkpoint_writes(1);
}
for (TaskNode *input_node : node.get_inputs()) {
if (input_node->next_finished(node) && !input_node->is_result()) {
remove_node(*input_node);
}
}
if (!node.get_nexts().empty()) {
for (TaskNode *nn : node.get_nexts()) {
if (nn->input_is_ready(&node)) {
cstate.add_pending_node(*nn);
}
}
} /*else {
remove_node(*node);
}*/
} else if (!node.is_result()) {
remove_node(node);
}
if (cstate.has_pending_nodes()) {
server.need_task_distribution();
......@@ -132,7 +146,7 @@ void TaskManager::on_task_finished(loom::base::Id id, size_t size, size_t length
void TaskManager::on_data_transferred(Id id, WorkerConnection *wc)
{
if (unlikely(wc->is_blocked())) {
wc->residual_task_finished(id, true);
wc->residual_task_finished(id, true, false);
return;
}
TaskNode &node = cstate.get_node(id);
......@@ -143,7 +157,7 @@ void TaskManager::on_data_transferred(Id id, WorkerConnection *wc)
void TaskManager::on_task_failed(Id id, WorkerConnection *wc, const std::string &error_msg)
{
if (unlikely(wc->is_blocked())) {
wc->residual_task_finished(id, false);
wc->residual_task_finished(id, false, false);
return;
}
logger->error("Task id={} failed on worker {}: {}",
......@@ -163,6 +177,92 @@ void TaskManager::on_task_failed(Id id, WorkerConnection *wc, const std::string
}
}
// Handles a worker's report that the checkpoint for `id` was written.
// On a blocked connection the event is only recorded as a residual
// checkpoint. Otherwise the worker's pending-write counter is decremented,
// the node is flagged as checkpointed and, if it is a result, the client is
// informed that the result is finished.
void TaskManager::on_checkpoint_write_finished(Id id, WorkerConnection *wc)
{
    if (unlikely(wc->is_blocked())) {
        wc->residual_checkpoint_finished(id);
        return;
    }

    logger->debug("Checkpoint id={} finished on worker {}", id, wc->get_address());
    wc->change_checkpoint_writes(-1);

    TaskNode &node = cstate.get_node(id);
    assert(node.has_defined_checkpoint());
    node.set_checkpoint();

    if (!node.is_result()) {
        return;
    }
    if (ClientConnection *client = server.get_client_connection()) {
        client->send_info_about_finished_result(node);
    }
}
// Handles a worker's report that the checkpoint for `id` was loaded.
// On a blocked connection the event is only recorded as a residual task.
// Otherwise the node becomes owned by `wc` with the reported size/length,
// the client is notified if the node is a result, successors whose inputs
// are now ready become pending, and a node with no successors that is not a
// result is removed.
void TaskManager::on_checkpoint_load_finished(Id id, WorkerConnection *wc, size_t size, size_t length)
{
    if (unlikely(wc->is_blocked())) {
        wc->residual_task_finished(id, false, false);
        return;
    }

    wc->change_checkpoint_loads(-1);
    TaskNode &node = cstate.get_node(id);
    node.set_as_loaded(wc, size, length);

    if (!node.is_result()) {
        logger->debug("Task id={} checkpoint loaded", id);
    } else {
        logger->debug("Task id={} [RESULT] checkpoint loaded", id);
        if (ClientConnection *client = server.get_client_connection()) {
            client->send_info_about_finished_result(node);
        }
    }

    if (node.get_nexts().empty()) {
        if (!node.is_result()) {
            remove_node(node);
        }
    } else {
        for (TaskNode *successor : node.get_nexts()) {
            if (successor->input_is_ready(&node)) {
                cstate.add_pending_node(*successor);
            }
        }
    }

    if (cstate.has_pending_nodes()) {
        server.need_task_distribution();
    }
}
// Handles a worker's report that writing the checkpoint for `id` failed.
// The worker's pending-write counter is always decremented; on a blocked
// connection nothing else happens. Otherwise the error is logged, the client
// (if connected) is told the task failed, and all tasks are trashed.
void TaskManager::on_checkpoint_write_failed(Id id, WorkerConnection *wc, const std::string &error_msg)
{
    wc->change_checkpoint_writes(-1);
    if (unlikely(wc->is_blocked())) {
        return;
    }

    logger->error("Checkpoint id={} failed on worker {}: {}",
                  id, wc->get_address(), error_msg);

    if (auto client = server.get_client_connection()) {
        client->send_task_failed(id, *wc, error_msg);
    }
    trash_all_tasks();
}
// Handles a worker's report that loading the checkpoint for `id` failed.
// On a blocked connection the event is only recorded as a residual task.
// Otherwise the worker's pending-load counter is decremented, the error is
// logged and the client (if connected) is told the task failed.
// NOTE(review): unlike on_checkpoint_write_failed this does not call
// trash_all_tasks() — confirm that leaving the plan in place is intended.
void TaskManager::on_checkpoint_load_failed(Id id, WorkerConnection *wc, const std::string &error_msg)
{
    if (unlikely(wc->is_blocked())) {
        wc->residual_task_finished(id, false, false);
        return;
    }

    wc->change_checkpoint_loads(-1);
    logger->error("Checkpoint id={} load failed on worker {}: {}",
                  id, wc->get_address(), error_msg);

    if (auto client = server.get_client_connection()) {
        client->send_task_failed(id, *wc, error_msg);
    }
}
void TaskManager::run_task_distribution()
{
uv_loop_t *loop = server.get_loop();
......@@ -198,25 +298,26 @@ void TaskManager::run_task_distribution()
}
void TaskManager::trash_all_tasks()
{
{
for (auto &wc : server.get_connections()) {
wc->change_residual_tasks(wc->get_checkpoint_loads());
wc->change_checkpoint_loads(-wc->get_checkpoint_loads());
}
cstate.foreach_node([](std::unique_ptr<TaskNode> &task) {
if (task->has_state()) {
task->foreach_worker([&task](WorkerConnection *wc, TaskStatus &status) {
if (status == TaskStatus::OWNER) {
wc->remove_data(task->get_id());
} else if (status == TaskStatus::RUNNING) {
wc->change_residual_tasks(1);
wc->free_resources(*task);
logger->debug("Residual task id={} on worker={}", task->get_id(), wc->get_worker_id());
} else {
assert(status == TaskStatus::TRANSFER);
wc->change_residual_tasks(1);
logger->debug("Residual transfer id={} on worker={}", task->get_id(), wc->get_worker_id());
}
status = TaskStatus::NONE;
});
}
task->foreach_worker([&task](WorkerConnection *wc, TaskStatus status) {
if (status == TaskStatus::OWNER) {
wc->remove_data(task->get_id());
} else if (status == TaskStatus::RUNNING) {
wc->change_residual_tasks(1);
wc->free_resources(*task);
logger->debug("Residual task id={} on worker={}", task->get_id(), wc->get_worker_id());
} else {
assert(status == TaskStatus::TRANSFER);
wc->change_residual_tasks(1);
logger->debug("Residual transfer id={} on worker={}", task->get_id(), wc->get_worker_id());
}
status = TaskStatus::NONE;
});
});
cstate.clear_all();
}
......@@ -231,3 +332,20 @@ void TaskManager::release_node(TaskNode *node)
node->reset_result_flag();
}
}
// Reacts to a lost worker: the client (if connected) receives a task-failed
// message with id -1 and the text "Worker lost", then all tasks are trashed.
void TaskManager::worker_fail(WorkerConnection &conn)
{
    if (auto client = server.get_client_connection()) {
        client->send_task_failed(-1, conn, "Worker lost");
    }
    trash_all_tasks();
}
// Returns one of the server's worker connections, chosen with rand().
// Asserts that at least one connection exists.
WorkerConnection *TaskManager::random_worker()
{
    auto &workers = server.get_connections();
    assert(!workers.empty());
    const auto pick = static_cast<size_t>(rand()) % workers.size();
    return workers[pick].get();
}
......@@ -23,11 +23,15 @@ public:
cstate.add_node(std::move(node));
}*/
loom::base::Id add_plan(const loom::pb::comm::Plan &plan);
loom::base::Id add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints);
void on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc);
void on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing);
void on_data_transferred(loom::base::Id id, WorkerConnection *wc);
void on_task_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg);
void on_checkpoint_write_finished(loom::base::Id id, WorkerConnection *wc);
void on_checkpoint_write_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg);
void on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc, size_t size, size_t length);
void on_checkpoint_load_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg);
int get_n_of_data_objects() const {
return cstate.get_n_data_objects();
......@@ -42,6 +46,10 @@ public:
void trash_all_tasks();
void release_node(TaskNode *node);
void fail_task_on_worker(WorkerConnection &conn);
void worker_fail(WorkerConnection &conn);
WorkerConnection *random_worker();
private:
Server &server;
......
......@@ -6,7 +6,11 @@
#include <sstream>
TaskNode::TaskNode(loom::base::Id id, TaskDef &&task)
: id(id), task(std::move(task))
: id(id),
task(std::move(task)),
size(0),
length(0),
remaining_inputs(0)
{
}
......@@ -14,27 +18,12 @@ TaskNode::TaskNode(loom::base::Id id, TaskDef &&task)
void TaskNode::reset_result_flag()
{
assert(is_result());
task.flags.reset(static_cast<size_t>(TaskFlags::RESULT));
}
bool TaskNode::is_computed() const {
if (!state) {
return false;
}
for(auto &pair : state->workers) {
if (pair.second == TaskStatus::OWNER) {
return true;
}
}
return false;
task.flags.reset(static_cast<size_t>(TaskDefFlags::RESULT));
}
WorkerConnection *TaskNode::get_random_owner()
{
if (!state) {
return nullptr;
}
for(auto &pair : state->workers) {
for(auto &pair : workers) {
if (pair.second == TaskStatus::OWNER) {
return pair.first;
}
......@@ -42,41 +31,9 @@ WorkerConnection *TaskNode::get_random_owner()
return nullptr;
}
void TaskNode::create_state(TaskNode *just_finishing_input)
{
/* just_finishing_input has to be introduced
* to solve situation when a task takes the same input multiple times
* and an input is finished. When we create a state in situation without
* finishing a task, than just_finishing_input has to be nullptr
*/
assert(!state);
state = std::make_unique<RuntimeState>();
state->size = 0;
state->length = 0;
int remaining_inputs = 0;
if (just_finishing_input) {
for (TaskNode *input_node : task.inputs) {
if (just_finishing_input == input_node || !input_node->is_computed()) {
remaining_inputs += 1;
}
}
} else {
for (TaskNode *input_node : task.inputs) {
if (!input_node->is_computed()) {
remaining_inputs += 1;
}
}
}
state->remaining_inputs = remaining_inputs;
}
bool TaskNode::is_active() const
{
if (!state) {
return false;
}
for (auto &pair : state->workers) {
for (auto &pair : workers) {
if (pair.second == TaskStatus::RUNNING || pair.second == TaskStatus::TRANSFER) {
return true;
}
......@@ -86,7 +43,7 @@ bool TaskNode::is_active() const
void TaskNode::reset_owners()
{
for (auto &pair : state->workers) {
for (auto &pair : workers) {
if (pair.second == TaskStatus::OWNER) {
pair.second = TaskStatus::NONE;
}
......@@ -115,22 +72,36 @@ void TaskNode::set_as_finished(WorkerConnection *wc, size_t size, size_t length)
assert(get_worker_status(wc) == TaskStatus::RUNNING);
wc->free_resources(*this);
set_worker_status(wc, TaskStatus::OWNER);
state->size = size;
state->length = length;
this->size = size;
this->length = length;
flags.set(static_cast<size_t>(TaskNodeFlags::FINISHED));
}
// Completes a load on worker `wc`: the worker must currently have status
// LOADING for this node; it becomes OWNER, the reported size and length are
// stored and the node is flagged as FINISHED.
void TaskNode::set_as_loaded(WorkerConnection *wc, size_t size, size_t length)
{
    assert(get_worker_status(wc) == TaskStatus::LOADING);
    set_worker_status(wc, TaskStatus::OWNER);

    this->size = size;
    this->length = length;

    const auto finished_bit = static_cast<size_t>(TaskNodeFlags::FINISHED);
    flags.set(finished_bit);
}
// Marks worker `wc` with status LOADING for this node.
void TaskNode::set_as_loading(WorkerConnection *wc)
{
    const TaskStatus new_status = TaskStatus::LOADING;
    set_worker_status(wc, new_status);
}
// Marks worker `wc` as OWNER, stores size/length and sets the FINISHED flag
// without asserting the worker's previous status (set_as_finished and
// set_as_loaded do assert it).
// NOTE(review): size/length are written both through `state->` and through
// `this->` below -- looks like old and new diff lines shown together; confirm
// which pair is live.
void TaskNode::set_as_finished_no_check(WorkerConnection *wc, size_t size, size_t length)
{
set_worker_status(wc, TaskStatus::OWNER);
state->size = size;
state->length = length;
this->size = size;
this->length = length;
flags.set(static_cast<size_t>(TaskNodeFlags::FINISHED));
}
std::string TaskNode::debug_str() const
{
std::stringstream s;
s << "<Node id=" << id;
for(auto &pair : state->workers) {
for(auto &pair : workers) {
s << ' ' << pair.first->get_address() << ':' << static_cast<int>(pair.second);
}
s << '>';
......@@ -156,8 +127,7 @@ void TaskNode::set_as_running(WorkerConnection *wc)
// Promotes worker `wc` from TRANSFER to OWNER for this node. The worker's
// status must already be TRANSFER (checked by the assert).
// NOTE(review): `s` is declared twice below (via `state->workers` and via
// `workers`) -- looks like old and new diff lines shown together; confirm
// which one is live.
void TaskNode::set_as_transferred(WorkerConnection *wc)
{
assert(state);
auto &s = state->workers[wc];
// operator[] inserts a default entry if `wc` is absent; the assert below
// then rejects it in debug builds.
auto &s = workers[wc];
assert(s == TaskStatus::TRANSFER);
s = TaskStatus::OWNER;
}
......@@ -14,10 +14,17 @@
class WorkerConnection;
class TaskNode;
// Bit positions used with TaskDef::flags (see is_result / reset_result_flag).
// NOTE(review): two enum headers appear below (TaskFlags and TaskDefFlags) --
// looks like the old and new name of the same enum shown together; confirm
// which one is live.
enum class TaskFlags : size_t {
enum class TaskDefFlags : size_t {
// Tested by is_result() and cleared by reset_result_flag().
RESULT
};
// Bit positions inside TaskNode's runtime flag set.
enum class TaskNodeFlags : size_t {
    FINISHED = 0,    // set by set_as_finished / set_as_loaded, read by is_computed()
    CHECKPOINT = 1,  // set by set_checkpoint(), read by has_checkpoint()
    PLANNED = 2,     // set by set_planned(), read by is_planned()
    FLAGS_COUNT = 3  // number of flags above; not a flag itself
};
struct TaskDef
{
int n_cpus; // TODO: Replace by resource index
......@@ -25,12 +32,14 @@ struct TaskDef
loom::base::Id task_type;
std::string config;
std::bitset<1> flags;
std::string checkpoint_path;
};
// Status of a task node with respect to one worker. The numeric values are
// printed by TaskNode::debug_str(), so they are spelled out explicitly.
enum class TaskStatus {
    NONE = 0,      // reported by get_worker_status() for a worker with no entry
    RUNNING = 1,   // counts as active (is_active)
    TRANSFER = 2,  // counts as active (is_active); becomes OWNER in set_as_transferred
    LOADING = 3,   // set by set_as_loading; becomes OWNER in set_as_loaded
    OWNER = 4,     // status that get_random_owner / foreach_owner look for
};
......@@ -42,37 +51,48 @@ class TaskNode {
public:
// Per-node runtime data kept separately from the task declaration.
struct RuntimeState {
// Status of this node on each worker.
WorkerMap<TaskStatus> workers;
// Values passed to set_as_finished(); units are not visible here.
size_t size;
size_t length;
// Count of inputs still outstanding; is_ready() compares it to zero.
size_t remaining_inputs;
};
TaskNode(loom::base::Id id, TaskDef &&task);
// Identifier this node was constructed with.
loom::base::Id get_id() const { return this->id; }
// NOTE(review): the has_state() header and body below are immediately followed
// by the is_result() header -- looks like old and new diff lines shown
// together; confirm which lines are live.
// has_state(): true when the runtime state object has been allocated.
bool has_state() const {
return state != nullptr;
// is_result(): true when the RESULT bit is set in the task definition's flags.
inline bool is_result() const {
return task.flags.test(static_cast<size_t>(TaskDefFlags::RESULT));
}
// NOTE(review): an is_result() header/body using TaskFlags is immediately
// followed by the is_computed() header -- looks like old and new diff lines
// shown together; confirm which lines are live.
inline bool is_result() const {
return task.flags.test(static_cast<size_t>(TaskFlags::RESULT));
// is_computed(): true when the node's FINISHED flag is set.
bool is_computed() const {
return flags.test(static_cast<size_t>(TaskNodeFlags::FINISHED));
}
// True when the CHECKPOINT flag of this node is set.
inline bool has_checkpoint() const {
    const auto checkpoint_bit = static_cast<size_t>(TaskNodeFlags::CHECKPOINT);
    return flags.test(checkpoint_bit);
}
// True when the task definition carries a non-empty checkpoint path.
bool has_defined_checkpoint() const {
    const bool path_missing = task.checkpoint_path.empty();
    return !path_missing;
}
// Sets the CHECKPOINT flag of this node.
void set_checkpoint() {
    const auto checkpoint_bit = static_cast<size_t>(TaskNodeFlags::CHECKPOINT);
    flags.set(checkpoint_bit);
}
// True when the PLANNED flag of this node is set.
bool is_planned() const {
    const auto planned_bit = static_cast<size_t>(TaskNodeFlags::PLANNED);
    return flags.test(planned_bit);
}
// Sets the PLANNED flag of this node.
void set_planned() {
    const auto planned_bit = static_cast<size_t>(TaskNodeFlags::PLANNED);
    flags.set(planned_bit);
}
void reset_result_flag();
// Returns the size stored for this node (written by set_as_finished,
// set_as_loaded and set_as_finished_no_check; zero-initialized otherwise).
// NOTE(review): two return statements appear below (`state->size` and `size`)
// -- looks like old and new diff lines shown together; confirm which is live.
inline size_t get_size() const {
//assert(state);
return state->size;
return size;
}
// Returns the length stored for this node (written by set_as_finished,
// set_as_loaded and set_as_finished_no_check; zero-initialized otherwise).
// NOTE(review): two return statements appear below (`state->length` and
// `length`) -- looks like old and new diff lines shown together; confirm
// which is live.
inline size_t get_length() const {
//assert(state);
return state->length;
return length;
}
int get_n_cpus() const {
......@@ -91,7 +111,6 @@ public:
return nexts;
}
bool is_computed() const;
bool is_active() const;
WorkerConnection* get_random_owner();
......@@ -100,49 +119,36 @@ public:
}
// Returns the status of this node on worker `wc`, or TaskStatus::NONE when the
// worker has no entry. Uses find(), so the lookup never inserts an entry.
// NOTE(review): the lookup appears twice below (through `state->workers` and
// through `workers`), together with a `!state` guard -- looks like old and new
// diff lines shown together; confirm which lines are live.
TaskStatus get_worker_status(WorkerConnection *wc) {
if (!state) {
return TaskStatus::NONE;
}
auto i = state->workers.find(wc);
if (i == state->workers.end()) {
auto i = workers.find(wc);
if (i == workers.end()) {
return TaskStatus::NONE;
}
return i->second;
}
// Stores `status` for worker `wc`, creating the map entry if it is missing.
// NOTE(review): both an ensure_state()/`state->workers` form and a plain
// `workers` form appear below -- looks like old and new diff lines shown
// together; confirm which lines are live.
void set_worker_status(WorkerConnection *wc, TaskStatus status) {
ensure_state();
state->workers[wc] = status;
workers[wc] = status;
}
// NOTE(review): the ensure_state() body below runs straight into the
// set_remaining_inputs() header -- looks like old and new diff lines shown
// together; confirm which lines are live.
// ensure_state(): creates the runtime state on first use.
inline void ensure_state() {
if (!state) {
create_state();
}
// set_remaining_inputs(): overwrites the outstanding-input counter.
// NOTE(review): `value` is int while the remaining_inputs member is declared
// size_t, so a negative value would wrap -- confirm callers never pass one.
void set_remaining_inputs(int value) {
remaining_inputs = value;
}
// Current value of the outstanding-input counter, converted to int.
int get_remaining_inputs() const {
    return static_cast<int>(remaining_inputs);
}
// Decrements the outstanding-input counter by one and returns true when it
// reaches zero. The counter must be positive on entry (checked by the assert).
// NOTE(review): the assert/decrement pair appears twice below (through
// `state->` and directly on the member), preceded by a `!state` branch that
// passes `node` to create_state -- looks like old and new diff lines shown
// together; confirm which lines are live.
inline bool input_is_ready(TaskNode *node) {
if (!state) {
create_state(node);
}
assert(state->remaining_inputs > 0);
return --state->remaining_inputs == 0;
assert(remaining_inputs > 0);
return --remaining_inputs == 0;
}
// True when no inputs are outstanding (remaining_inputs == 0).
// NOTE(review): two forms appear below -- one that falls back to
// _slow_is_ready() when there is no state and reads `state->remaining_inputs`,
// and one that reads the member directly -- looks like old and new diff lines
// shown together; confirm which lines are live.
inline bool is_ready() const {
if (!state) {
return _slow_is_ready();
}
return state->remaining_inputs == 0;
return remaining_inputs == 0;
}
void create_state(TaskNode *just_finishing_input = nullptr);
template<typename F> inline void foreach_owner(const F &f) const {
if (!state) {
return;
}
for(auto &pair : state->workers) {
for(auto &pair : workers) {
if (pair.second == TaskStatus::OWNER) {
f(pair.first);
}
......@@ -150,10 +156,7 @@ public:
}
template<typename F> inline void foreach_worker(const F &f) const {
if (!state) {
return;
}
for(auto &pair : state->workers) {
for(auto &pair : workers) {
if (pair.second != TaskStatus::NONE) {
f(pair.first, pair.second);
}
......@@ -163,27 +166,40 @@ public:
void reset_owners();
// Read-only access to the worker -> status map of this node. The map may hold
// entries with status NONE (reset_owners assigns NONE without erasing).
// NOTE(review): two return statements appear below (`state->workers` and
// `workers`) -- looks like old and new diff lines shown together; confirm
// which is live.
const WorkerMap<TaskStatus>& get_workers() const {
return state->workers;
return workers;
}
bool next_finished(TaskNode &);
void set_as_finished(WorkerConnection *wc, size_t size, size_t length);
void set_as_loaded(WorkerConnection *wc, size_t size, size_t length);
void set_as_running(WorkerConnection *wc);
void set_as_loading(WorkerConnection *wc);
void set_as_transferred(WorkerConnection *wc);
void set_as_none(WorkerConnection *wc);
// For unit testing
void set_as_finished_no_check(WorkerConnection *wc, size_t size, size_t length);
// Clears the PLANNED and FINISHED flags of this node; the CHECKPOINT flag is
// left untouched.
void set_not_needed() {
    const auto planned_bit = static_cast<size_t>(TaskNodeFlags::PLANNED);
    const auto finished_bit = static_cast<size_t>(TaskNodeFlags::FINISHED);
    flags.reset(planned_bit);
    flags.reset(finished_bit);
}
std::string debug_str() const;
private:
// Declaration
loom::base::Id id;
TaskDef task;
std::unordered_multiset<TaskNode*> nexts;
std::unique_ptr<RuntimeState> state;
// Runtime info
// Bit set indexed by TaskNodeFlags (FINISHED, CHECKPOINT, PLANNED).
// NOTE(review): the width 3 is hard-coded and happens to equal
// TaskNodeFlags::FLAGS_COUNT; deriving it from FLAGS_COUNT would keep the two
// in sync when a flag is added.
std::bitset<3> flags;
// Status of this node on each worker; entries are created by set_worker_status.
WorkerMap<TaskStatus> workers;
// Written by set_as_finished / set_as_loaded / set_as_finished_no_check.
size_t size;
size_t length;
// Outstanding-input counter used by input_is_ready() and is_ready().
// NOTE(review): declared size_t, but set_remaining_inputs takes int and
// get_remaining_inputs returns int -- confirm the intended type.
size_t remaining_inputs;
bool _slow_is_ready() const;
};
......
......@@ -23,7 +23,9 @@ WorkerConnection::WorkerConnection(Server &server,
task_types(task_types),
data_types(data_types),
worker_id(worker_id),
n_residual_tasks(0)
n_residual_tasks(0),
checkpoint_writes(0),
checkpoint_loads(0)
{
logger->info("Worker {} connected (cpus={})", address, resource_cpus);
if (this->socket) {
......@@ -48,21 +50,47 @@ void WorkerConnection::on_message(const char *buffer, size_t size)
WorkerResponse msg;
msg.ParseFromArray(buffer, size);
if (msg.type() == WorkerResponse_Type_FINISHED) {
server.on_task_finished(msg.id(), msg.size(), msg.length(), this);
auto type = msg.type();
if (type == WorkerResponse_Type_FINISHED) {
server.on_task_finished(msg.id(), msg.size(), msg.length(), this, false);
return;
}
if (msg.type() == WorkerResponse_Type_TRANSFERED) {
if (type == WorkerResponse_Type_FINISHED_AND_CHECKPOINTING) {
server.on_task_finished(msg.id(), msg.size(), msg.length(), this, true);
return;
}
if (type == WorkerResponse_Type_TRANSFERED) {
server.on_data_transferred(msg.id(), this);
return;
}
if (msg.type() == WorkerResponse_Type_FAILED) {
if (type == WorkerResponse_Type_FAILED) {
assert(msg.has_error_msg());
server.on_task_failed(msg.id(), this, msg.error_msg());
return;
}
if (type == WorkerResponse_Type_CHECKPOINT_WRITTEN) {
server.on_checkpoint_write_finished(msg.id(), this);
return;
}
if (type == WorkerResponse_Type_CHECKPOINT_WRITE_FAILED) {
server.on_checkpoint_write_failed(msg.id(), this, msg.error_msg());
return;
}
if (type == WorkerResponse_Type_CHECKPOINT_LOADED) {
server.on_checkpoint_load_finished(msg.id(), this, msg.size(), msg.length());
return;
}
if (type == WorkerResponse_Type_CHECKPOINT_LOAD_FAILED) {
server.on_checkpoint_load_failed(msg.id(), this, msg.error_msg());
return;
}
}
void WorkerConnection::send_task(const TaskNode &task)
......@@ -79,6 +107,7 @@ void WorkerConnection::send_task(const TaskNode &task)
msg.set_task_type(def.task_type);
msg.set_task_config(def.config);
msg.set_n_cpus(def.n_cpus);
msg.set_checkpoint_path(def.checkpoint_path);
for (TaskNode *input_node : task.get_inputs()) {
msg.add_task_inputs(input_node->get_id());
......@@ -98,6 +127,19 @@ void WorkerConnection::send_data(Id id, const std::string &address)
send_message(*socket, msg);
}
// Sends a LOAD_CHECKPOINT command for data object `id` with the given
// checkpoint path to this worker. The worker's checkpoint_loads counter is
// incremented before the command is sent.
void WorkerConnection::load_checkpoint(Id id, const std::string &checkpoint_path)
{
    using namespace loom::pb::comm;

    ++checkpoint_loads;
    logger->debug("Command for {}: LOAD_CHECKPOINT id={} path={}", address, id, checkpoint_path);

    WorkerCommand command;
    command.set_type(WorkerCommand_Type_LOAD_CHECKPOINT);
    command.set_id(id);
    command.set_checkpoint_path(checkpoint_path);

    send_message(*socket, command);
}
void WorkerConnection::remove_data(Id id)
{
using namespace loom::pb::comm;
......@@ -113,10 +155,15 @@ void WorkerConnection::free_resources(TaskNode &node)
add_free_cpus(node.get_n_cpus());
}
void WorkerConnection::residual_task_finished(Id id, bool success)
void WorkerConnection::residual_task_finished(Id id, bool success, bool checkpointing)