diff --git a/pb/comm.proto b/pb/comm.proto index ccf22849b144ddc9e3cc45ac6fab8838ecd619fa..ea1f2fd1712867d776e62d4ddd2d53962705274e 100644 --- a/pb/comm.proto +++ b/pb/comm.proto @@ -164,6 +164,7 @@ message ClientRequest { // PLAN optional Plan plan = 2; + optional bool load_checkpoints = 4; // FETCH + RELEASE optional int32 id = 3; diff --git a/python/loom/client/client.py b/python/loom/client/client.py index 5729db487470aaeaa90228eef3a912752a7d4ec6..7d0e7dad034d92536f5c837be2bbc40259b1328b 100644 --- a/python/loom/client/client.py +++ b/python/loom/client/client.py @@ -241,7 +241,7 @@ class Client(object): print(t) assert 0 - def submit_one(self, task): + def submit_one(self, task, load=False): """Submits a task to the server and returns a future Args: @@ -256,9 +256,9 @@ class Client(object): >>> result = client.submit(task3) >>> print(result.gather()) """ - return self.submit((task,))[0] + return self.submit((task,), load=load)[0] - def submit(self, tasks): + def submit(self, tasks, load=False): """Submits tasks to the server and returns list of futures Args: @@ -294,7 +294,7 @@ class Client(object): msg = ClientRequest() msg.type = ClientRequest.PLAN - + msg.load_checkpoints = load include_metadata = self.trace_path is not None msg.plan.id_base = id_base plan.set_message( diff --git a/src/server/clientconn.cpp b/src/server/clientconn.cpp index 22c9fcbf992bee86ba4c058319ed8b3ace981416..a876e5064745c840722db6da4fc2252fffdb6d0f 100644 --- a/src/server/clientconn.cpp +++ b/src/server/clientconn.cpp @@ -56,8 +56,8 @@ void ClientConnection::on_message(const char *buffer, size_t size) case ClientRequest_Type_PLAN: { logger->debug("Plan received"); const Plan &plan = request.plan(); - loom::base::Id id_base = task_manager.add_plan(plan); - logger->info("Plan submitted tasks={}", plan.tasks_size()); + loom::base::Id id_base = task_manager.add_plan(plan, request.load_checkpoints()); + logger->info("Plan submitted tasks={}, load_checkpoints={}", plan.tasks_size(), request.load_checkpoints()); if (server.get_trace()) { server.create_file_in_trace_dir(std::to_string(id_base) + ".plan", buffer, size); diff --git a/src/server/compstate.cpp b/src/server/compstate.cpp index 29909354c9e9373b0c8f4320f123e6191d487722..cc2084457a9c42c8fd02734238200fcddb0ff509 100644 --- a/src/server/compstate.cpp +++ b/src/server/compstate.cpp @@ -23,26 +23,17 @@ ComputationState::ComputationState(Server &server) : server(server) void ComputationState::add_node(std::unique_ptr<TaskNode> &&node) { auto id = node->get_id(); - /* - for (TaskNode* input_node : node->get_inputs()) { - input_node->add_next(node.get()); - } - - if (node->is_ready()) { - pending_nodes.insert(node.get()); - }*/ - auto result = nodes.insert(std::make_pair(id, std::move(node))); assert(result.second); // Check that ID is fresh } -void ComputationState::plan_node(TaskNode &node, std::vector<TaskNode *> &to_load) { +void ComputationState::plan_node(TaskNode &node, bool load_checkpoints, std::vector<TaskNode *> &to_load) { if (node.is_planned()) { return; } node.set_planned(); - if (!node.get_task_def().checkpoint_path.empty() && loom::base::file_exists(node.get_task_def().checkpoint_path.c_str())) { + if (load_checkpoints && !node.get_task_def().checkpoint_path.empty() && loom::base::file_exists(node.get_task_def().checkpoint_path.c_str())) { node.set_checkpoint(); to_load.push_back(&node); return; @@ -50,7 +41,7 @@ void ComputationState::plan_node(TaskNode &node, std::vector<TaskNode *> &to_loa int remaining_inputs = 0; for (TaskNode *input_node : node.get_inputs()) { - plan_node(*input_node, to_load); + plan_node(*input_node, load_checkpoints, to_load); if (!input_node->is_computed()) { remaining_inputs += 1; input_node->add_next(&node); @@ -269,7 +260,7 @@ void ComputationState::make_expansion(std::vector<std::string> &configs, } }*/ -loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, std::vector<TaskNode*> &to_load) +loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints, std::vector<TaskNode*> &to_load) { auto task_size = plan.tasks_size(); assert(plan.has_id_base()); @@ -317,7 +308,7 @@ loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, std: auto new_node = std::make_unique<TaskNode>(id, std::move(def)); if (is_result) { - plan_node(*new_node.get(), to_load); + plan_node(*new_node.get(), load_checkpoints, to_load); } add_node(std::move(new_node)); } diff --git a/src/server/compstate.h b/src/server/compstate.h index 5220a59895d2bee39fcbaa020fb78ca14e3b6dc4..797042bfbdc2d355bd73b25e80344c18d318ebad 100644 --- a/src/server/compstate.h +++ b/src/server/compstate.h @@ -43,7 +43,7 @@ public: int get_n_data_objects() const; - loom::base::Id add_plan(const loom::pb::comm::Plan &plan, std::vector<TaskNode *> &to_load); + loom::base::Id add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints, std::vector<TaskNode *> &to_load); void test_ready_nodes(std::vector<loom::base::Id> ids); loom::base::Id pop_result_client_id(loom::base::Id id); @@ -61,7 +61,7 @@ public: std::unique_ptr<TaskNode> pop_node(loom::base::Id id); void clear_all(); void add_pending_node(TaskNode &node); - void plan_node(TaskNode &node, std::vector<TaskNode*> &to_load); + void plan_node(TaskNode &node, bool load_checkpoints, std::vector<TaskNode*> &to_load); private: std::unordered_map<loom::base::Id, std::unique_ptr<TaskNode>> nodes; diff --git a/src/server/taskmanager.cpp b/src/server/taskmanager.cpp index aaa0e953367f443543942ef401f2dd7831f5c5b1..aca80595fcbe96c55ca6fcc880a9bf87b8bc9e66 100644 --- a/src/server/taskmanager.cpp +++ b/src/server/taskmanager.cpp @@ -20,10 +20,10 @@ TaskManager::TaskManager(Server &server) { } -loom::base::Id TaskManager::add_plan(const loom::pb::comm::Plan &plan) +loom::base::Id TaskManager::add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints) { std::vector<TaskNode*> to_load; - loom::base::Id id_base = cstate.add_plan(plan, to_load); + loom::base::Id id_base = cstate.add_plan(plan, load_checkpoints, to_load); for (TaskNode *node : to_load) { WorkerConnection *wc = random_worker(); node->set_as_loading(wc); diff --git a/src/server/taskmanager.h b/src/server/taskmanager.h index 9b34eb23729fe5f1766dec95c60c7bba8a0ff2ec..aabf619d3241a481e4b58eba0a62c4b821c3a273 100644 --- a/src/server/taskmanager.h +++ b/src/server/taskmanager.h @@ -23,7 +23,7 @@ public: cstate.add_node(std::move(node)); }*/ - loom::base::Id add_plan(const loom::pb::comm::Plan &plan); + loom::base::Id add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints); void on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing); void on_data_transferred(loom::base::Id id, WorkerConnection *wc); diff --git a/tests/client/loomenv.py b/tests/client/loomenv.py index 698629a44fa5aaf0e59bf9282289d54437a4b1f8..4f7dde5c3ae7ac2e41f137d593706f38c2237a20 100644 --- a/tests/client/loomenv.py +++ b/tests/client/loomenv.py @@ -140,12 +140,12 @@ class LoomEnv(Env): self.check_stats() return self._client - def submit_and_gather(self, tasks, check=True): + def submit_and_gather(self, tasks, check=True, load=False): if isinstance(tasks, Task): - future = self.client.submit_one(tasks) + future = self.client.submit_one(tasks, load=load) return self.client.gather_one(future) else: - futures = self.client.submit(tasks) + futures = self.client.submit(tasks, load=load) return self.client.gather(futures) if check: self.check_final_state() diff --git a/tests/client/test_checkpoint.py b/tests/client/test_checkpoint.py index 71299aa9fbde0467eaa7c0dbf63e13dba6735e99..2c5d2c4c6986d55e4faaf30e836b45e3153c12e3 100644 --- a/tests/client/test_checkpoint.py +++ b/tests/client/test_checkpoint.py @@ -51,4 +51,4 @@ def test_checkpoint_load(loom_env): x4 = tasks.merge((x3, x1, x2, t1, t2, t3)) x4.checkpoint_path = path5 - assert loom_env.submit_and_gather(x4) == b'[4][1][2]$t3$[3][1][2]$t3$' \ No newline at end of file + assert loom_env.submit_and_gather(x4, load=True) == b'[4][1][2]$t3$[3][1][2]$t3$' \ No newline at end of file diff --git a/tests/cpp/test_scheduler.cpp b/tests/cpp/test_scheduler.cpp index 1f79e97ea863130923ff30676c14f8c5807932a5..5e8000ad9680f4d1b78d9f4a386774a513f0163c 100644 --- a/tests/cpp/test_scheduler.cpp +++ b/tests/cpp/test_scheduler.cpp @@ -319,7 +319,7 @@ static std::vector<TaskNode*> nodes(ComputationState &s, std::vector<loom::base: static void add_plan(ComputationState &s, const loom::pb::comm::Plan &plan) { std::vector<TaskNode*> to_load; - s.add_plan(plan, to_load); + s.add_plan(plan, false, to_load); assert(to_load.empty()); }