diff --git a/pb/comm.proto b/pb/comm.proto index c994818fc92038ce4b7b0b5edb1b10b340fd1372..ccf22849b144ddc9e3cc45ac6fab8838ecd619fa 100644 --- a/pb/comm.proto +++ b/pb/comm.proto @@ -63,6 +63,7 @@ message WorkerCommand { REMOVE = 3; DICTIONARY = 8; UPDATE = 9; + LOAD_CHECKPOINT = 10; } required Type type = 1; @@ -74,6 +75,8 @@ message WorkerCommand { optional string task_config = 4; repeated int32 task_inputs = 5; optional int32 n_cpus = 6; + + // TASK + LOAD_CHECKPOINT optional string checkpoint_path = 7; // SEND @@ -93,13 +96,15 @@ message WorkerResponse { FINISHED_AND_CHECKPOINTING = 2; TRANSFERED = 3; FAILED = 4; - CHECKPOINT_FINISHED = 5; - CHECKPOINT_FAILED = 6; + CHECKPOINT_WRITTEN = 5; + CHECKPOINT_WRITE_FAILED = 6; + CHECKPOINT_LOADED = 7; + CHECKPOINT_LOAD_FAILED = 8; } required Type type = 1; required int32 id = 2; - // FINISHED + // FINISHED + CHECKPOINT_LOADED optional uint64 size = 3; optional uint64 length = 4; diff --git a/src/libloom/fsutils.cpp b/src/libloom/fsutils.cpp index fe273b4a127393ea189e4e8b8d2e0417f40b2e7a..9840ca2b98b87e4758853a0389bef0878fd5544c 100644 --- a/src/libloom/fsutils.cpp +++ b/src/libloom/fsutils.cpp @@ -61,3 +61,9 @@ int loom::base::make_path(const char *path, mode_t mode) } return 0; } + +bool loom::base::file_exists(const char *path) +{ + struct stat buffer; + return (stat(path, &buffer) == 0); +} diff --git a/src/libloom/fsutils.h b/src/libloom/fsutils.h index c878c374ce40289e3ced862eb8575f39414591d8..49bbd4fe9e01203dab25774c8cbfdbf4d3ad467f 100644 --- a/src/libloom/fsutils.h +++ b/src/libloom/fsutils.h @@ -12,6 +12,7 @@ namespace base { int make_path(const char *path, mode_t mode); size_t file_size(const char *path); +bool file_exists(const char *path); } } diff --git a/src/libloomw/checkpointwriter.cpp b/src/libloomw/checkpointwriter.cpp index fb4c943d0c0274b0c7ce39ec19e00b128ddc003a..f369f07d3a2827cfff0234e5816790b60caddcc5 100644 --- a/src/libloomw/checkpointwriter.cpp +++ b/src/libloomw/checkpointwriter.cpp @@ -27,6 +27,14 @@ void loom::CheckPointWriter::_work_cb(uv_work_t *req) writer->error = "Writing checkpoint '" + path + "' failed. Cannot create " + tmp_path + ":" + strerror(errno); return; } + + const char *ptr = writer->data->get_raw_data(); + + if (!ptr) { + writer->error = "Data '" + writer->data->get_info() + "' cannot be checkpointed"; + return; + } + fout.write(writer->data->get_raw_data(), writer->data->get_size()); fout.close(); @@ -43,7 +51,7 @@ void loom::CheckPointWriter::_after_work_cb(uv_work_t *req, int status) if (writer->error.empty()) { writer->worker.checkpoint_written(writer->id); } else { - writer->worker.checkpoint_failed(writer->id, writer->error); + writer->worker.checkpoint_write_failed(writer->id, writer->error); } delete writer; } diff --git a/src/libloomw/worker.cpp b/src/libloomw/worker.cpp index 90cb940f4f872de59038882730eaf7cf074fd61d..f1140ebe53c5dd3dc495d3b319f3e6a30266617f 100644 --- a/src/libloomw/worker.cpp +++ b/src/libloomw/worker.cpp @@ -21,6 +21,7 @@ #include "libloom/sendbuffer.h" #include "libloom/pbutils.h" #include "libloom/fsutils.h" +#include "data/externfile.h" #include "loom_define.h" #include "checkpointwriter.h" @@ -288,24 +289,47 @@ void Worker::checkpoint_written(Id id) { logger->debug("Checkpoint written id={}", id); if (server_conn.is_connected()) { loom::pb::comm::WorkerResponse msg; - msg.set_type(loom::pb::comm::WorkerResponse_Type_CHECKPOINT_FINISHED); + msg.set_type(loom::pb::comm::WorkerResponse_Type_CHECKPOINT_WRITTEN); msg.set_id(id); send_message(server_conn, msg); } } -void Worker::checkpoint_failed(Id id, const std::string &error_msg) { +void Worker::checkpoint_write_failed(Id id, const std::string &error_msg) { logger->debug("Cannot write checkpoint id={}, error={}", id, error_msg); if (server_conn.is_connected()) { loom::pb::comm::WorkerResponse msg; - msg.set_type(loom::pb::comm::WorkerResponse_Type_CHECKPOINT_FAILED); + msg.set_type(loom::pb::comm::WorkerResponse_Type_CHECKPOINT_WRITE_FAILED); msg.set_id(id); msg.set_error_msg(error_msg); send_message(server_conn, msg); } } -void Worker:: publish_data(Id id, const DataPtr &data, const std::string &checkpoint_path) +void Worker::checkpoint_loaded(Id id, const DataPtr &data) { + logger->debug("Checkpoint loaded id={}", id); + if (server_conn.is_connected()) { + loom::pb::comm::WorkerResponse msg; + msg.set_type(loom::pb::comm::WorkerResponse_Type_CHECKPOINT_LOADED); + msg.set_id(id); + msg.set_size(data->get_size()); + msg.set_length(data->get_length()); + send_message(server_conn, msg); + } +} + +void Worker::checkpoint_load_failed(Id id, const std::string &error_msg) { + logger->debug("Cannot load checkpoint id={}, error={}", id, error_msg); + if (server_conn.is_connected()) { + loom::pb::comm::WorkerResponse msg; + msg.set_type(loom::pb::comm::WorkerResponse_Type_CHECKPOINT_LOAD_FAILED); + msg.set_id(id); + msg.set_error_msg(error_msg); + send_message(server_conn, msg); + } +} + +void Worker::publish_data(Id id, const DataPtr &data, const std::string &checkpoint_path) { logger->debug("Publishing data id={} size={} info={}", id, data->get_size(), data->get_info()); public_data[id] = data; @@ -507,6 +531,20 @@ void Worker::remove_task(TaskInstance &task, bool free_resources) assert(0); } +void Worker::load_checkpoint(Id id, const std::string &path) { + if (!file_exists(path.c_str())) { + std::stringstream s; + s << "File '" << path << "' does not exists"; + std::string error = s.str(); + logger->error("Cannot load checkpoint {}: {}", id, error); + checkpoint_load_failed(id, error); + return; + } + DataPtr data = std::make_shared<ExternFile>(path); + checkpoint_loaded(id, data); + publish_data(id, data, ""); +} + void Worker::task_failed(TaskInstance &task, const std::string &error_msg) { logger->error("Task id={} failed: {}", task.get_id(), error_msg); @@ -637,6 +675,12 @@ void Worker::on_message(const char *data, size_t size) } break; } + case comm::WorkerCommand_Type_LOAD_CHECKPOINT: { + logger->debug("Loading checkpoint id={} path={}", msg.id(), msg.checkpoint_path()); + assert(msg.has_checkpoint_path()); + load_checkpoint(msg.id(), msg.checkpoint_path()); + break; + } case comm::WorkerCommand_Type_DICTIONARY: { auto count = msg.symbols_size(); logger->debug("New dictionary ({} symbols)", count); diff --git a/src/libloomw/worker.h b/src/libloomw/worker.h index abbbde22e7b8d2d66d107b4eea8bffc3063828d9..bf8e5422342103d3b8cacf5d0cd695e4b2b338e1 100644 --- a/src/libloomw/worker.h +++ b/src/libloomw/worker.h @@ -127,10 +127,16 @@ public: return trace; } + + void on_dictionary_updated(); + + void load_checkpoint(base::Id id, const std::string &path); void checkpoint_written(base::Id id); - void checkpoint_failed(base::Id id, const std::string &error_msg); + void checkpoint_write_failed(base::Id id, const std::string &error_msg); + void checkpoint_loaded(base::Id id, const DataPtr &Data); + void checkpoint_load_failed(base::Id id, const std::string &error_msg); private: void register_worker(); void create_trace(const std::string &trace_path, loom::base::Id worker_id); diff --git a/src/server/compstate.cpp b/src/server/compstate.cpp index 40a7e6961e70ebe39a9d5ada98ca0abab5af5145..29909354c9e9373b0c8f4320f123e6191d487722 100644 --- a/src/server/compstate.cpp +++ b/src/server/compstate.cpp @@ -5,6 +5,7 @@ #include "pb/comm.pb.h" #include "libloom/log.h" +#include "libloom/fsutils.h" constexpr static double TRANSFER_COST_COEF = 1.0 / (1024 * 1024); // 1MB = 1cost @@ -22,18 +23,45 @@ ComputationState::ComputationState(Server &server) : server(server) void ComputationState::add_node(std::unique_ptr<TaskNode> &&node) { auto id = node->get_id(); + /* for (TaskNode* input_node : node->get_inputs()) { input_node->add_next(node.get()); } if (node->is_ready()) { pending_nodes.insert(node.get()); - } + }*/ auto result = nodes.insert(std::make_pair(id, std::move(node))); assert(result.second); // Check that ID is fresh } +void ComputationState::plan_node(TaskNode &node, std::vector<TaskNode *> &to_load) { + if (node.is_planned()) { + return; + } + node.set_planned(); + + if (!node.get_task_def().checkpoint_path.empty() && loom::base::file_exists(node.get_task_def().checkpoint_path.c_str())) { + node.set_checkpoint(); + to_load.push_back(&node); + return; + } + + int remaining_inputs = 0; + for (TaskNode *input_node : node.get_inputs()) { + plan_node(*input_node, to_load); + if (!input_node->is_computed()) { + remaining_inputs += 1; + input_node->add_next(&node); + } + } + node.set_remaining_inputs(remaining_inputs); + if (remaining_inputs == 0) { + pending_nodes.insert(&node); + } +} + void ComputationState::reserve_new_nodes(size_t size) { nodes.reserve(nodes.size() + size); @@ -241,7 +269,7 @@ void ComputationState::make_expansion(std::vector<std::string> &configs, } }*/ -loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan) +loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, std::vector<TaskNode*> &to_load) { auto task_size = plan.tasks_size(); assert(plan.has_id_base()); @@ -268,7 +296,9 @@ loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan) def.task_type = pt.task_type(); def.config = pt.config(); def.checkpoint_path = pt.checkpoint_path(); + bool is_result = false; if (pt.has_result() && pt.result()) { + is_result = true; def.flags.set(static_cast<size_t>(TaskDefFlags::RESULT)); } @@ -284,7 +314,12 @@ loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan) n_cpus = resources[pt.resource_request_index()]; } def.n_cpus = n_cpus; - add_node(std::make_unique<TaskNode>(id, std::move(def))); + + auto new_node = std::make_unique<TaskNode>(id, std::move(def)); + if (is_result) { + plan_node(*new_node.get(), to_load); + } + add_node(std::move(new_node)); } return id_base; } diff --git a/src/server/compstate.h b/src/server/compstate.h index c930e18c42745e8afa0b702e30cd966a7ad52054..5220a59895d2bee39fcbaa020fb78ca14e3b6dc4 100644 --- a/src/server/compstate.h +++ b/src/server/compstate.h @@ -41,11 +41,9 @@ public: void remove_node(TaskNode &node); - bool is_ready(const TaskNode &node); - int get_n_data_objects() const; - loom::base::Id add_plan(const loom::pb::comm::Plan &plan); + loom::base::Id add_plan(const loom::pb::comm::Plan &plan, std::vector<TaskNode *> &to_load); void test_ready_nodes(std::vector<loom::base::Id> ids); loom::base::Id pop_result_client_id(loom::base::Id id); @@ -62,7 +60,8 @@ public: std::unique_ptr<TaskNode> pop_node(loom::base::Id id); void clear_all(); - void add_pending_node(TaskNode &node); + void add_pending_node(TaskNode &node); + void plan_node(TaskNode &node, std::vector<TaskNode*> &to_load); private: std::unordered_map<loom::base::Id, std::unique_ptr<TaskNode>> nodes; diff --git a/src/server/server.cpp b/src/server/server.cpp index 432a2773ea3c51992d57a79f2ba91c1668f6845c..9bf63f69b84371445f4a7bd1f369df8d4971e616 100644 --- a/src/server/server.cpp +++ b/src/server/server.cpp @@ -81,14 +81,24 @@ void Server::remove_freshconnection(FreshConnection &conn) fresh_connections.erase(i); } -void Server::on_checkpoint_finished(loom::base::Id id, WorkerConnection *wc) +void Server::on_checkpoint_write_finished(loom::base::Id id, WorkerConnection *wc) { - task_manager.on_checkpoint_finished(id, wc); + task_manager.on_checkpoint_write_finished(id, wc); } -void Server::on_checkpoint_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg) +void Server::on_checkpoint_write_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg) { - task_manager.on_checkpoint_failed(id, wc, error_msg); + task_manager.on_checkpoint_write_failed(id, wc, error_msg); +} + +void Server::on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc, size_t size, size_t length) +{ + task_manager.on_checkpoint_load_finished(id, wc, size, length); +} + +void Server::on_checkpoint_load_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg) +{ + task_manager.on_checkpoint_load_failed(id, wc, error_msg); } void Server::on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing) diff --git a/src/server/server.h b/src/server/server.h index 81417840f00ce6193394e9bfeab171047e378165..3382500548e6decaaec8143baae4eaa6d6b1ddef 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -88,9 +88,10 @@ public: return trace; } - void on_checkpoint_finished(loom::base::Id id, WorkerConnection *wc); - void on_checkpoint_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg); - + void on_checkpoint_write_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg); + void on_checkpoint_write_finished(loom::base::Id id, WorkerConnection *wc); + void on_checkpoint_load_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg); + void on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc, size_t size, size_t length); private: diff --git a/src/server/taskmanager.cpp b/src/server/taskmanager.cpp index 0ed438d1b0b26df7a3b157bf18f81d3b83465e9f..aaa0e953367f443543942ef401f2dd7831f5c5b1 100644 --- a/src/server/taskmanager.cpp +++ b/src/server/taskmanager.cpp @@ -9,6 +9,8 @@ #include <algorithm> #include <assert.h> #include <limits.h> +#include <stdlib.h> +#include <memory> using namespace loom; using namespace loom::base; @@ -20,7 +22,13 @@ TaskManager::TaskManager(Server &server) loom::base::Id TaskManager::add_plan(const loom::pb::comm::Plan &plan) { - loom::base::Id id_base = cstate.add_plan(plan); + std::vector<TaskNode*> to_load; + loom::base::Id id_base = cstate.add_plan(plan, to_load); + for (TaskNode *node : to_load) { + WorkerConnection *wc = random_worker(); + node->set_as_loading(wc); + wc->load_checkpoint(node->get_id(), node->get_task_def().checkpoint_path); + } distribute_work(schedule(cstate)); return id_base; } @@ -76,7 +84,8 @@ void TaskManager::remove_node(TaskNode &node) assert(status == TaskStatus::OWNER); wc->remove_data(id); }); - cstate.remove_node(node); + node.set_not_needed(); + //cstate.remove_node(node); } void TaskManager::on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing) @@ -108,7 +117,7 @@ void TaskManager::on_task_finished(loom::base::Id id, size_t size, size_t length } if (checkpointing) { - wc->change_running_checkpoints(1); + wc->change_checkpoint_writes(1); } for (TaskNode *input_node : node.get_inputs()) { @@ -123,9 +132,9 @@ void TaskManager::on_task_finished(loom::base::Id id, size_t size, size_t length cstate.add_pending_node(*nn); } } - } /*else { - remove_node(*node); - }*/ + } else if (!node.is_result()) { + remove_node(node); + } if (cstate.has_pending_nodes()) { server.need_task_distribution(); @@ -166,14 +175,14 @@ void TaskManager::on_task_failed(Id id, WorkerConnection *wc, const std::string } } -void TaskManager::on_checkpoint_finished(Id id, WorkerConnection *wc) +void TaskManager::on_checkpoint_write_finished(Id id, WorkerConnection *wc) { if (unlikely(wc->is_blocked())) { wc->residual_checkpoint_finished(id); return; } logger->debug("Checkpoint id={} finished on worker {}", id, wc->get_address()); - wc->change_running_checkpoints(-1); + wc->change_checkpoint_writes(-1); TaskNode &node = cstate.get_node(id); assert(node.has_defined_checkpoint()); node.set_checkpoint(); @@ -186,14 +195,50 @@ void TaskManager::on_checkpoint_finished(Id id, WorkerConnection *wc) } } -void TaskManager::on_checkpoint_failed(Id id, WorkerConnection *wc, const std::string &error_msg) +void TaskManager::on_checkpoint_load_finished(Id id, WorkerConnection *wc, size_t size, size_t length) { + if (unlikely(wc->is_blocked())) { + wc->residual_task_finished(id, false, false); + return; + } + wc->change_checkpoint_loads(-1); + + TaskNode &node = cstate.get_node(id); + node.set_as_loaded(wc, size, length); + + if (node.is_result()) { + logger->debug("Task id={} [RESULT] checkpoint loaded", id); + + ClientConnection *cc = server.get_client_connection(); + if (cc) { + cc->send_info_about_finished_result(node); + } + } else { + logger->debug("Task id={} checkpoint loaded", id); + } + + if (!node.get_nexts().empty()) { + for (TaskNode *nn : node.get_nexts()) { + if (nn->input_is_ready(&node)) { + cstate.add_pending_node(*nn); + } + } + } else if (!node.is_result()) { + remove_node(node); + } + if (cstate.has_pending_nodes()) { + server.need_task_distribution(); + } +} + +void TaskManager::on_checkpoint_write_failed(Id id, WorkerConnection *wc, const std::string &error_msg) +{ + wc->change_checkpoint_writes(-1); if (unlikely(wc->is_blocked())) { return; } logger->error("Checkpoint id={} failed on worker {}: {}", id, wc->get_address(), error_msg); - wc->change_running_checkpoints(-1); auto cc = server.get_client_connection(); if (cc) { cc->send_task_failed(id, *wc, error_msg); @@ -201,6 +246,21 @@ void TaskManager::on_checkpoint_failed(Id id, WorkerConnection *wc, const std::s trash_all_tasks(); } +void TaskManager::on_checkpoint_load_failed(Id id, WorkerConnection *wc, const std::string &error_msg) +{ + if (unlikely(wc->is_blocked())) { + wc->residual_task_finished(id, false, false); + return; + } + wc->change_checkpoint_loads(-1); + logger->error("Checkpoint id={} load failed on worker {}: {}", + id, wc->get_address(), error_msg); + auto cc = server.get_client_connection(); + if (cc) { + cc->send_task_failed(id, *wc, error_msg); + } +} + void TaskManager::run_task_distribution() { uv_loop_t *loop = server.get_loop(); @@ -237,6 +297,10 @@ void TaskManager::run_task_distribution() void TaskManager::trash_all_tasks() { + for (auto &wc : server.get_connections()) { + wc->change_residual_tasks(wc->get_checkpoint_loads()); + wc->change_checkpoint_loads(-wc->get_checkpoint_loads()); + } cstate.foreach_node([](std::unique_ptr<TaskNode> &task) { task->foreach_worker([&task](WorkerConnection *wc, TaskStatus status) { if (status == TaskStatus::OWNER) { @@ -266,3 +330,11 @@ void TaskManager::release_node(TaskNode *node) node->reset_result_flag(); } } + +WorkerConnection *TaskManager::random_worker() +{ + auto &connections = server.get_connections(); + assert(!connections.empty()); + int index = rand() % connections.size(); + return connections[index].get(); +} diff --git a/src/server/taskmanager.h b/src/server/taskmanager.h index 153712381d7d2def6543b80b5c93cc6349962116..9b34eb23729fe5f1766dec95c60c7bba8a0ff2ec 100644 --- a/src/server/taskmanager.h +++ b/src/server/taskmanager.h @@ -28,9 +28,10 @@ public: void on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing); void on_data_transferred(loom::base::Id id, WorkerConnection *wc); void on_task_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg); - void on_checkpoint_finished(loom::base::Id id, WorkerConnection *wc); - void on_checkpoint_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg); - + void on_checkpoint_write_finished(loom::base::Id id, WorkerConnection *wc); + void on_checkpoint_write_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg); + void on_checkpoint_load_finished(loom::base::Id id, WorkerConnection *wc, size_t size, size_t length); + void on_checkpoint_load_failed(loom::base::Id id, WorkerConnection *wc, const std::string &error_msg); int get_n_of_data_objects() const { return cstate.get_n_data_objects(); @@ -45,6 +46,7 @@ public: void trash_all_tasks(); void release_node(TaskNode *node); + WorkerConnection *random_worker(); private: Server &server; diff --git a/src/server/tasknode.cpp b/src/server/tasknode.cpp index f29aba63d7de35ecf89ba881872a762a98cefede..3214c88cc2e4cc8372a8bd5206b66297810778bb 100644 --- a/src/server/tasknode.cpp +++ b/src/server/tasknode.cpp @@ -12,11 +12,7 @@ TaskNode::TaskNode(loom::base::Id id, TaskDef &&task) length(0), remaining_inputs(0) { - for (TaskNode *t : this->task.inputs) { - if (!t->is_computed()) { - remaining_inputs += 1; - } - } + } void TaskNode::reset_result_flag() @@ -81,6 +77,18 @@ void TaskNode::set_as_finished(WorkerConnection *wc, size_t size, size_t length) flags.set(static_cast<size_t>(TaskNodeFlags::FINISHED)); } +void TaskNode::set_as_loaded(WorkerConnection *wc, size_t size, size_t length) { + assert(get_worker_status(wc) == TaskStatus::LOADING); + set_worker_status(wc, TaskStatus::OWNER); + this->size = size; + this->length = length; + flags.set(static_cast<size_t>(TaskNodeFlags::FINISHED)); +} + +void TaskNode::set_as_loading(WorkerConnection *wc) { + set_worker_status(wc, TaskStatus::LOADING); +} + void TaskNode::set_as_finished_no_check(WorkerConnection *wc, size_t size, size_t length) { set_worker_status(wc, TaskStatus::OWNER); diff --git a/src/server/tasknode.h b/src/server/tasknode.h index 36d01c01d4a96652cbd69d32cb3b0c6de1f1bdd6..6d89839681e0141686a99fceaf8d32e2b285ecd2 100644 --- a/src/server/tasknode.h +++ b/src/server/tasknode.h @@ -20,7 +20,9 @@ enum class TaskDefFlags : size_t { enum class TaskNodeFlags : size_t { FINISHED, - CHECKPOINT + CHECKPOINT, + PLANNED, + FLAGS_COUNT }; struct TaskDef @@ -37,6 +39,7 @@ enum class TaskStatus { NONE, RUNNING, TRANSFER, + LOADING, OWNER, }; @@ -74,6 +77,14 @@ public: flags.set(static_cast<size_t>(TaskNodeFlags::CHECKPOINT)); } + bool is_planned() const { + return flags.test(static_cast<size_t>(TaskNodeFlags::PLANNED)); + } + + void set_planned() { + flags.set(static_cast<size_t>(TaskNodeFlags::PLANNED)); + } + void reset_result_flag(); inline size_t get_size() const { @@ -119,6 +130,14 @@ public: workers[wc] = status; } + void set_remaining_inputs(int value) { + remaining_inputs = value; + } + + int get_remaining_inputs() const { + return remaining_inputs; + } + inline bool input_is_ready(TaskNode *node) { assert(remaining_inputs > 0); return --remaining_inputs == 0; @@ -153,12 +172,18 @@ public: bool next_finished(TaskNode &); void set_as_finished(WorkerConnection *wc, size_t size, size_t length); + void set_as_loaded(WorkerConnection *wc, size_t size, size_t length); void set_as_running(WorkerConnection *wc); + void set_as_loading(WorkerConnection *wc); void set_as_transferred(WorkerConnection *wc); void set_as_none(WorkerConnection *wc); // For unit testing void set_as_finished_no_check(WorkerConnection *wc, size_t size, size_t length); + void set_not_needed() { + flags.reset(static_cast<size_t>(TaskNodeFlags::PLANNED)); + flags.reset(static_cast<size_t>(TaskNodeFlags::FINISHED)); + } std::string debug_str() const; @@ -170,7 +195,7 @@ private: std::unordered_multiset<TaskNode*> nexts; // Runtime info - std::bitset<2> flags; + std::bitset<3> flags; WorkerMap<TaskStatus> workers; size_t size; size_t length; diff --git a/src/server/workerconn.cpp b/src/server/workerconn.cpp index fbf4d8b4599e49d78b010d17c91b76bb0e3dc65f..e47b2bbb5412d49d933b3743c2fbe86baeb4df88 100644 --- a/src/server/workerconn.cpp +++ b/src/server/workerconn.cpp @@ -24,7 +24,8 @@ WorkerConnection::WorkerConnection(Server &server, data_types(data_types), worker_id(worker_id), n_residual_tasks(0), - running_checkpoints(0) + checkpoint_writes(0), + checkpoint_loads(0) { logger->info("Worker {} connected (cpus={})", address, resource_cpus); if (this->socket) { @@ -71,13 +72,23 @@ void WorkerConnection::on_message(const char *buffer, size_t size) return; } - if (type == WorkerResponse_Type_CHECKPOINT_FINISHED) { - server.on_checkpoint_finished(msg.id(), this); + if (type == WorkerResponse_Type_CHECKPOINT_WRITTEN) { + server.on_checkpoint_write_finished(msg.id(), this); return; } - if (type == WorkerResponse_Type_CHECKPOINT_FAILED) { - server.on_checkpoint_failed(msg.id(), this, msg.error_msg()); + if (type == WorkerResponse_Type_CHECKPOINT_WRITE_FAILED) { + server.on_checkpoint_write_failed(msg.id(), this, msg.error_msg()); + return; + } + + if (type == WorkerResponse_Type_CHECKPOINT_LOADED) { + server.on_checkpoint_load_finished(msg.id(), this, msg.size(), msg.length()); + return; + } + + if (type == WorkerResponse_Type_CHECKPOINT_LOAD_FAILED) { + server.on_checkpoint_load_failed(msg.id(), this, msg.error_msg()); return; } } @@ -116,6 +127,19 @@ void WorkerConnection::send_data(Id id, const std::string &address) send_message(*socket, msg); } +void WorkerConnection::load_checkpoint(Id id, const std::string &checkpoint_path) +{ + checkpoint_loads += 1; + using namespace loom::pb::comm; + logger->debug("Command for {}: LOAD_CHECKPOINT id={} path={}", this->address, id, checkpoint_path); + + WorkerCommand msg; + msg.set_type(WorkerCommand_Type_LOAD_CHECKPOINT); + msg.set_id(id); + msg.set_checkpoint_path(checkpoint_path); + send_message(*socket, msg); +} + void WorkerConnection::remove_data(Id id) { using namespace loom::pb::comm; @@ -137,7 +161,7 @@ void WorkerConnection::residual_task_finished(Id id, bool success, bool checkpoi change_residual_tasks(-1); if (checkpointing) { - running_checkpoints += 1; + checkpoint_writes += 1; } if (success) { @@ -151,7 +175,7 @@ void WorkerConnection::residual_task_finished(Id id, bool success, bool checkpoi void WorkerConnection::residual_checkpoint_finished(Id id) { logger->debug("Residual checkpoint id={} finished on {}", id, address); - running_checkpoints -= 1; + checkpoint_writes -= 1; if (!is_blocked()) { server.need_task_distribution(); } diff --git a/src/server/workerconn.h b/src/server/workerconn.h index e3221d3e1ea73b221a205e5536bb81ddb7c7b608..1da0875644558ad73c2f8e9eb875d3c1fdc4c5d7 100644 --- a/src/server/workerconn.h +++ b/src/server/workerconn.h @@ -70,15 +70,23 @@ public: } bool is_blocked() const { - return n_residual_tasks > 0 && running_checkpoints > 0; + return n_residual_tasks > 0 && checkpoint_writes > 0; } void change_residual_tasks(int value) { n_residual_tasks += value; } - void change_running_checkpoints(int value) { - n_residual_checkpoints += value; + int get_checkpoint_loads() const { + return checkpoint_loads; + } + + void change_checkpoint_writes(int value) { + checkpoint_writes += value; + } + + void change_checkpoint_loads(int value) { + checkpoint_loads += value; } void free_resources(TaskNode &node); @@ -87,6 +95,7 @@ public: void residual_checkpoint_finished(loom::base::Id id); void create_trace(const std::string &trace_path); + void load_checkpoint(loom::base::Id id, const std::string &checkpoint_path); private: Server &server; @@ -101,7 +110,8 @@ private: int worker_id; int n_residual_tasks; int n_residual_checkpoints; - int running_checkpoints; + int checkpoint_writes; + int checkpoint_loads; int scheduler_index; int scheduler_free_cpus; diff --git a/tests/client/test_checkpoint.py b/tests/client/test_checkpoint.py index ae748f078b8a7855a87c5e4ec222ddc1eb4ed782..71299aa9fbde0467eaa7c0dbf63e13dba6735e99 100644 --- a/tests/client/test_checkpoint.py +++ b/tests/client/test_checkpoint.py @@ -17,3 +17,38 @@ def test_checkpoint_basic(loom_env): with open(path, "rb") as f: assert f.read() == b"abcdXYZ" assert not os.path.isfile(path + ".loom.tmp") + + +def test_checkpoint_load(loom_env): + loom_env.start(1) + + path1 = os.path.join(LOOM_TEST_BUILD_DIR, "f1.txt") + path2 = os.path.join(LOOM_TEST_BUILD_DIR, "f2.txt") + path3 = os.path.join(LOOM_TEST_BUILD_DIR, "f3.txt") + path4 = os.path.join(LOOM_TEST_BUILD_DIR, "f4.txt") + path5 = os.path.join(LOOM_TEST_BUILD_DIR, "nonexisting") + + for i, p in enumerate((path1, path2, path3, path4)): + with open(p, "w") as f: + f.write("[{}]".format(i + 1)) + + t1 = tasks.const("$t1$") + t1.checkpoint_path = path1 # This shoud load: [1] + + t2 = tasks.const("$t2$") + t2.checkpoint_path = path2 # This shoud load: [2] + + t3 = tasks.const("$t3$") + t4 = tasks.const("$t4$") + + x1 = tasks.merge((t1, t2, t3)) # [1][2]$t3$ + x2 = tasks.merge((t1, x1)) + x2.checkpoint_path = path3 # loaded as [3] + + x3 = tasks.merge((t4, t4)) + x3.checkpoint_path = path4 # loaded as [4] + + x4 = tasks.merge((x3, x1, x2, t1, t2, t3)) + x4.checkpoint_path = path5 + + assert loom_env.submit_and_gather(x4) == b'[4][1][2]$t3$[3][1][2]$t3$' \ No newline at end of file diff --git a/tests/cpp/test_scheduler.cpp b/tests/cpp/test_scheduler.cpp index be296c7ebc50f95a3d6190d1483bf17ff6413034..1f79e97ea863130923ff30676c14f8c5807932a5 100644 --- a/tests/cpp/test_scheduler.cpp +++ b/tests/cpp/test_scheduler.cpp @@ -317,10 +317,16 @@ static std::vector<TaskNode*> nodes(ComputationState &s, std::vector<loom::base: return result; } +static void add_plan(ComputationState &s, const loom::pb::comm::Plan &plan) { + std::vector<TaskNode*> to_load; + s.add_plan(plan, to_load); + assert(to_load.empty()); +} + TEST_CASE("basic-plan", "[scheduling]") { Server server(NULL, 0); ComputationState s(server); - s.add_plan(make_simple_plan(server)); + add_plan(s, make_simple_plan(server)); auto w1 = simple_worker(server, "w1"); auto w2 = simple_worker(server, "w2"); @@ -388,7 +394,7 @@ TEST_CASE("basic-plan", "[scheduling]") { TEST_CASE("plan4", "[scheduling]") { Server server(NULL, 0); ComputationState s(server); - s.add_plan(make_plan4(server)); + add_plan(s, make_plan4(server)); SECTION("More narrow") { auto w1 = simple_worker(server, "w1", 1); @@ -473,7 +479,7 @@ TEST_CASE("plan4", "[scheduling]") { TEST_CASE("Plan2", "[scheduling]") { Server server(NULL, 0); ComputationState s(server); - s.add_plan(make_plan2(server)); + add_plan(s, make_plan2(server)); SECTION("Two simple workers") { auto w1 = simple_worker(server, "w1"); @@ -562,7 +568,7 @@ TEST_CASE("big-plan", "[scheduling]") { Server server(NULL, 0); ComputationState s(server); - s.add_plan(make_big_plan(server, BIG_PLAN_SIZE)); + add_plan(s, make_big_plan(server, BIG_PLAN_SIZE)); std::vector<WorkerConnection*> ws; ws.reserve(BIG_PLAN_WORKERS); @@ -593,7 +599,7 @@ TEST_CASE("big-simple-plan", "[scheduling]") { Server server(NULL, 0); ComputationState s(server); - s.add_plan(make_big_trivial_plan(server, BIG_PLAN_SIZE)); + add_plan(s, make_big_trivial_plan(server, BIG_PLAN_SIZE)); std::vector<WorkerConnection*> ws; ws.reserve(BIG_PLAN_WORKERS); @@ -620,7 +626,7 @@ TEST_CASE("big-simple-plan", "[scheduling]") { TEST_CASE("request-plan", "[scheduling]") { Server server(NULL, 0); ComputationState s(server); - s.add_plan(make_request_plan(server)); + add_plan(s, make_request_plan(server)); SECTION("0 cpus - include free tasks") { auto w1 = simple_worker(server, "w1", 0); @@ -693,7 +699,7 @@ TEST_CASE("request-plan", "[scheduling]") { TEST_CASE("continuation2", "[scheduling]") { Server server(NULL, 0); ComputationState s(server); - s.add_plan(make_plan2(server)); + add_plan(s, make_plan2(server)); /*SECTION("Stick together") { auto w1 = simple_worker(server, "w1", 2); @@ -719,7 +725,7 @@ TEST_CASE("continuation2", "[scheduling]") { TEST_CASE("continuation", "[scheduling]") { Server server(NULL, 0); ComputationState s(server); - s.add_plan(make_plan3(server)); + add_plan(s, make_plan3(server)); SECTION("Stick together - inputs dominant") { auto w1 = simple_worker(server, "w1", 2); @@ -821,7 +827,7 @@ TEST_CASE("benchmark1", "[benchmark][!hide]") { for (size_t n_workers = 10; n_workers < 600; n_workers *= 2) { Server server(NULL, 0); ComputationState s(server); - s.add_plan(plan); + add_plan(s, plan); std::vector<WorkerConnection*> ws; ws.reserve(n_workers); @@ -845,7 +851,6 @@ TEST_CASE("benchmark1", "[benchmark][!hide]") { } } - TEST_CASE("benchmark2", "[benchmark][!hide]") { using namespace std::chrono; const size_t CPUS = 24; @@ -857,7 +862,7 @@ TEST_CASE("benchmark2", "[benchmark][!hide]") { for (size_t n_workers = 10; n_workers <= 160; n_workers *= 2) { Server server(NULL, 0); ComputationState s(server); - s.add_plan(plan); + add_plan(s, plan); std::vector<WorkerConnection*> ws; ws.reserve(n_workers); for (size_t i = 0; i < n_workers; i++) {