Commit ae1bb892 authored by Stanislav Bohm's avatar Stanislav Bohm

WIP

parent d50e3aaf
......@@ -63,6 +63,7 @@ message WorkerCommand {
REMOVE = 3;
DICTIONARY = 8;
UPDATE = 9;
LOAD_CHECKPOINT = 10;
}
required Type type = 1;
......@@ -74,6 +75,8 @@ message WorkerCommand {
optional string task_config = 4;
repeated int32 task_inputs = 5;
optional int32 n_cpus = 6;
// TASK + LOAD_CHECKPOINT
optional string checkpoint_path = 7;
// SEND
......@@ -101,7 +104,7 @@ message WorkerResponse {
required Type type = 1;
required int32 id = 2;
// FINISHED
// FINISHED + CHECKPOINT_LOADED
optional uint64 size = 3;
optional uint64 length = 4;
......
......@@ -50,8 +50,6 @@ add_library(libloomw
task.h
checkpointwriter.h
checkpointwriter.cpp
checkpointloader.h
checkpointloader.cpp
wtrace.cpp
wtrace.h
taskdesc.h
......
#include "checkpointloader.h"
#include "worker.h"
#include "libloom/log.h"
#include <stdio.h>
loom::CheckPointLoader::CheckPointLoader(loom::Worker &worker, base::Id id, const std::string &path)
: worker(worker), id(id), path(path)
{
work.data = this;
}
void loom::CheckPointLoader::start()
{
UV_CHECK(uv_queue_work(worker.get_loop(), &work, _work_cb, _after_work_cb));
}
void loom::CheckPointLoader::_work_cb(uv_work_t *req)
{
CheckPointLoader *loader = static_cast<CheckPointLoader*>(req->data);
loader->error = "TODO";
/*
std::string &path = writer->path;
std::string tmp_path = path + ".loom.tmp";
std::ofstream fout(tmp_path.c_str());
if (!fout.is_open()) {
writer->error = "Writing checkpoint '" + path + "' failed. Cannot create " + tmp_path + ":" + strerror(errno);
return;
}
const char *ptr = writer->data->get_raw_data();
if (!ptr) {
writer->error = "Data '" + writer->data->get_info() + "' cannot be checkpointed";
return;
}
fout.write(writer->data->get_raw_data(), writer->data->get_size());
fout.close();
if (rename(tmp_path.c_str(), path.c_str())) {
writer->error = "Writing checkpoint '" + path + "' failed. Cannot move " + tmp_path;
unlink(tmp_path.c_str());
}*/
}
void loom::CheckPointLoader::_after_work_cb(uv_work_t *req, int status)
{
UV_CHECK(status);
CheckPointLoader *loader = static_cast<CheckPointLoader*>(req->data);
if (loader->error.empty()) {
loader->worker.checkpoint_loaded(loader->id);
} else {
loader->worker.checkpoint_load_failed(loader->id, loader->error);
}
delete loader;
}
#ifndef LIBLOOMW_CHECKPOINTWRITER_H
#define LIBLOOMW_CHECKPOINTWRITER_H
#include "libloom/types.h"
#include "data.h"
namespace loom {
class CheckPointLoader {
public:
CheckPointLoader(Worker &worker, base::Id id, const std::string &path);
void start();
protected:
uv_work_t work;
Worker &worker;
base::Id id;
std::string path;
std::string error;
private:
static void _work_cb(uv_work_t *req);
static void _after_work_cb(uv_work_t *req, int status);
};
}
#endif
......@@ -21,6 +21,7 @@
#include "libloom/sendbuffer.h"
#include "libloom/pbutils.h"
#include "libloom/fsutils.h"
#include "data/externfile.h"
#include "loom_define.h"
#include "checkpointwriter.h"
......@@ -305,12 +306,14 @@ void Worker::checkpoint_write_failed(Id id, const std::string &error_msg) {
}
}
void Worker::checkpoint_loaded(Id id) {
void Worker::checkpoint_loaded(Id id, const DataPtr &data) {
logger->debug("Checkpoint loaded id={}", id);
if (server_conn.is_connected()) {
loom::pb::comm::WorkerResponse msg;
msg.set_type(loom::pb::comm::WorkerResponse_Type_CHECKPOINT_LOADED);
msg.set_id(id);
msg.set_size(data->get_size());
msg.set_length(data->get_length());
send_message(server_conn, msg);
}
}
......@@ -326,7 +329,7 @@ void Worker::checkpoint_load_failed(Id id, const std::string &error_msg) {
}
}
void Worker:: publish_data(Id id, const DataPtr &data, const std::string &checkpoint_path)
void Worker::publish_data(Id id, const DataPtr &data, const std::string &checkpoint_path)
{
logger->debug("Publishing data id={} size={} info={}", id, data->get_size(), data->get_info());
public_data[id] = data;
......@@ -528,6 +531,20 @@ void Worker::remove_task(TaskInstance &task, bool free_resources)
assert(0);
}
void Worker::load_checkpoint(Id id, const std::string &path) {
if (!file_exists(path.c_str())) {
std::stringstream s;
s << "File '" << path << "' does not exists";
std::string error = s.str();
logger->error("Cannot load checkpoint {}: {}", id, error);
checkpoint_load_failed(id, error);
return;
}
DataPtr data = std::make_shared<ExternFile>(path);
checkpoint_loaded(id, data);
publish_data(id, data, "");
}
void Worker::task_failed(TaskInstance &task, const std::string &error_msg)
{
logger->error("Task id={} failed: {}", task.get_id(), error_msg);
......@@ -658,6 +675,12 @@ void Worker::on_message(const char *data, size_t size)
}
break;
}
case comm::WorkerCommand_Type_LOAD_CHECKPOINT: {
logger->debug("Loading checkpoint id={} path={}", msg.id(), msg.checkpoint_path());
assert(msg.has_checkpoint_path());
load_checkpoint(msg.id(), msg.checkpoint_path());
break;
}
case comm::WorkerCommand_Type_DICTIONARY: {
auto count = msg.symbols_size();
logger->debug("New dictionary ({} symbols)", count);
......
......@@ -127,11 +127,15 @@ public:
return trace;
}
void on_dictionary_updated();
void load_checkpoint(base::Id id, const std::string &path);
void checkpoint_written(base::Id id);
void checkpoint_write_failed(base::Id id, const std::string &error_msg);
void checkpoint_loaded(base::Id id);
void checkpoint_loaded(base::Id id, const DataPtr &Data);
void checkpoint_load_failed(base::Id id, const std::string &error_msg);
private:
void register_worker();
......
......@@ -49,7 +49,7 @@ void ComputationState::plan_node(TaskNode &node, std::vector<TaskNode *> &to_loa
int remaining_inputs = 0;
for (TaskNode *input_node : node.get_inputs()) {
plan_node(*input_node);
plan_node(*input_node, to_load);
if (!input_node->is_computed()) {
remaining_inputs += 1;
input_node->add_next(&node);
......
......@@ -91,9 +91,9 @@ void Server::on_checkpoint_write_failed(loom::base::Id id, WorkerConnection *wc,
task_manager.on_checkpoint_write_failed(id, wc, error_msg);
}
void Server::on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc)
void Server::on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc, size_t size, size_t length)
{
task_manager.on_checkpoint_load_finished(id, wc);
task_manager.on_checkpoint_load_finished(id, wc, size, length);
}
void Server::on_checkpoint_load_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg)
......
......@@ -91,7 +91,7 @@ public:
void on_checkpoint_write_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg);
void on_checkpoint_write_finished(loom::base::Id id, WorkerConnection *wc);
void on_checkpoint_load_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg);
void on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc);
void on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc, size_t size, size_t length);
private:
......
......@@ -10,6 +10,7 @@
#include <assert.h>
#include <limits.h>
#include <stdlib.h>
#include <memory>
using namespace loom;
using namespace loom::base;
......@@ -24,7 +25,7 @@ loom::base::Id TaskManager::add_plan(const loom::pb::comm::Plan &plan)
std::vector<TaskNode*> to_load;
loom::base::Id id_base = cstate.add_plan(plan, to_load);
for (TaskNode *node : to_load) {
random_worker()->load_checkpoint(node);
random_worker()->load_checkpoint(node->get_id(), node->get_task_def().checkpoint_path);
}
distribute_work(schedule(cstate));
return id_base;
......@@ -192,7 +193,7 @@ void TaskManager::on_checkpoint_write_finished(Id id, WorkerConnection *wc)
}
}
void TaskManager::on_checkpoint_load_finished(Id id, WorkerConnection *wc)
void TaskManager::on_checkpoint_load_finished(Id id, WorkerConnection *wc, size_t size, size_t length)
{
if (unlikely(wc->is_blocked())) {
wc->residual_task_finished(id, false, false);
......@@ -306,6 +307,8 @@ void TaskManager::release_node(TaskNode *node)
WorkerConnection *TaskManager::random_worker()
{
assert(!server.get_connections().empty());
TODO
auto &connections = server.get_connections();
assert(!connections.empty());
int index = rand() % connections.size();
return connections[index].get();
}
......@@ -30,7 +30,7 @@ public:
void on_task_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg);
void on_checkpoint_write_finished(loom::base::Id id, WorkerConnection *wc);
void on_checkpoint_write_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg);
void on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc);
void on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc, size_t size, size_t length);
void on_checkpoint_load_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg);
int get_n_of_data_objects() const {
......
......@@ -83,7 +83,7 @@ void WorkerConnection::on_message(const char *buffer, size_t size)
}
if (type == WorkerResponse_Type_CHECKPOINT_LOADED) {
server.on_checkpoint_load_finished(msg.id(), this);
server.on_checkpoint_load_finished(msg.id(), this, msg.size(), msg.length());
return;
}
......@@ -127,6 +127,18 @@ void WorkerConnection::send_data(Id id, const std::string &address)
send_message(*socket, msg);
}
void WorkerConnection::load_checkpoint(Id id, const std::string &checkpoint_path)
{
using namespace loom::pb::comm;
logger->debug("Command for {}: LOAD_CHECKPOINT id={} path={}", this->address, id, checkpoint_path);
WorkerCommand msg;
msg.set_type(WorkerCommand_Type_LOAD_CHECKPOINT);
msg.set_id(id);
msg.set_checkpoint_path(checkpoint_path);
send_message(*socket, msg);
}
void WorkerConnection::remove_data(Id id)
{
using namespace loom::pb::comm;
......
......@@ -95,6 +95,7 @@ public:
void residual_checkpoint_finished(loom::base::Id id);
void create_trace(const std::string &trace_path);
void load_checkpoint(loom::base::Id id, const std::string &checkpoint_path);
private:
Server &server;
......
......@@ -317,10 +317,16 @@ static std::vector<TaskNode*> nodes(ComputationState &s, std::vector<loom::base:
return result;
}
static void add_plan(ComputationState &s, const loom::pb::comm::Plan &plan) {
std::vector<TaskNode*> to_load;
s.add_plan(plan, to_load);
assert(to_load.empty());
}
TEST_CASE("basic-plan", "[scheduling]") {
Server server(NULL, 0);
ComputationState s(server);
s.add_plan(make_simple_plan(server));
add_plan(s, make_simple_plan(server));
auto w1 = simple_worker(server, "w1");
auto w2 = simple_worker(server, "w2");
......@@ -388,7 +394,7 @@ TEST_CASE("basic-plan", "[scheduling]") {
TEST_CASE("plan4", "[scheduling]") {
Server server(NULL, 0);
ComputationState s(server);
s.add_plan(make_plan4(server));
add_plan(s, make_plan4(server));
SECTION("More narrow") {
auto w1 = simple_worker(server, "w1", 1);
......@@ -473,7 +479,7 @@ TEST_CASE("plan4", "[scheduling]") {
TEST_CASE("Plan2", "[scheduling]") {
Server server(NULL, 0);
ComputationState s(server);
s.add_plan(make_plan2(server));
add_plan(s, make_plan2(server));
SECTION("Two simple workers") {
auto w1 = simple_worker(server, "w1");
......@@ -562,7 +568,7 @@ TEST_CASE("big-plan", "[scheduling]") {
Server server(NULL, 0);
ComputationState s(server);
s.add_plan(make_big_plan(server, BIG_PLAN_SIZE));
add_plan(s, make_big_plan(server, BIG_PLAN_SIZE));
std::vector<WorkerConnection*> ws;
ws.reserve(BIG_PLAN_WORKERS);
......@@ -593,7 +599,7 @@ TEST_CASE("big-simple-plan", "[scheduling]") {
Server server(NULL, 0);
ComputationState s(server);
s.add_plan(make_big_trivial_plan(server, BIG_PLAN_SIZE));
add_plan(s, make_big_trivial_plan(server, BIG_PLAN_SIZE));
std::vector<WorkerConnection*> ws;
ws.reserve(BIG_PLAN_WORKERS);
......@@ -620,7 +626,7 @@ TEST_CASE("big-simple-plan", "[scheduling]") {
TEST_CASE("request-plan", "[scheduling]") {
Server server(NULL, 0);
ComputationState s(server);
s.add_plan(make_request_plan(server));
add_plan(s, make_request_plan(server));
SECTION("0 cpus - include free tasks") {
auto w1 = simple_worker(server, "w1", 0);
......@@ -693,7 +699,7 @@ TEST_CASE("request-plan", "[scheduling]") {
TEST_CASE("continuation2", "[scheduling]") {
Server server(NULL, 0);
ComputationState s(server);
s.add_plan(make_plan2(server));
add_plan(s, make_plan2(server));
/*SECTION("Stick together") {
auto w1 = simple_worker(server, "w1", 2);
......@@ -719,7 +725,7 @@ TEST_CASE("continuation2", "[scheduling]") {
TEST_CASE("continuation", "[scheduling]") {
Server server(NULL, 0);
ComputationState s(server);
s.add_plan(make_plan3(server));
add_plan(s, make_plan3(server));
SECTION("Stick together - inputs dominant") {
auto w1 = simple_worker(server, "w1", 2);
......@@ -821,7 +827,7 @@ TEST_CASE("benchmark1", "[benchmark][!hide]") {
for (size_t n_workers = 10; n_workers < 600; n_workers *= 2) {
Server server(NULL, 0);
ComputationState s(server);
s.add_plan(plan);
add_plan(s, plan);
std::vector<WorkerConnection*> ws;
ws.reserve(n_workers);
......@@ -845,7 +851,6 @@ TEST_CASE("benchmark1", "[benchmark][!hide]") {
}
}
TEST_CASE("benchmark2", "[benchmark][!hide]") {
using namespace std::chrono;
const size_t CPUS = 24;
......@@ -857,7 +862,7 @@ TEST_CASE("benchmark2", "[benchmark][!hide]") {
for (size_t n_workers = 10; n_workers <= 160; n_workers *= 2) {
Server server(NULL, 0);
ComputationState s(server);
s.add_plan(plan);
add_plan(s, plan);
std::vector<WorkerConnection*> ws;
ws.reserve(n_workers);
for (size_t i = 0; i < n_workers; i++) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment