Skip to content
Snippets Groups Projects
Commit ada5cfb0 authored by Stanislav Bohm's avatar Stanislav Bohm
Browse files

ENH: Submit parameter "load"

parent a16f5917
Branches
Tags
No related merge requests found
...@@ -164,6 +164,7 @@ message ClientRequest { ...@@ -164,6 +164,7 @@ message ClientRequest {
// PLAN // PLAN
optional Plan plan = 2; optional Plan plan = 2;
optional bool load_checkpoints = 4;
// FETCH + RELEASE // FETCH + RELEASE
optional int32 id = 3; optional int32 id = 3;
......
...@@ -241,7 +241,7 @@ class Client(object): ...@@ -241,7 +241,7 @@ class Client(object):
print(t) print(t)
assert 0 assert 0
def submit_one(self, task): def submit_one(self, task, load=False):
"""Submits a task to the server and returns a future """Submits a task to the server and returns a future
Args: Args:
...@@ -256,9 +256,9 @@ class Client(object): ...@@ -256,9 +256,9 @@ class Client(object):
>>> result = client.submit(task3) >>> result = client.submit(task3)
>>> print(result.gather()) >>> print(result.gather())
""" """
return self.submit((task,))[0] return self.submit((task,), load=load)[0]
def submit(self, tasks): def submit(self, tasks, load=False):
"""Submits tasks to the server and returns list of futures """Submits tasks to the server and returns list of futures
Args: Args:
...@@ -294,7 +294,7 @@ class Client(object): ...@@ -294,7 +294,7 @@ class Client(object):
msg = ClientRequest() msg = ClientRequest()
msg.type = ClientRequest.PLAN msg.type = ClientRequest.PLAN
msg.load_checkpoints = load
include_metadata = self.trace_path is not None include_metadata = self.trace_path is not None
msg.plan.id_base = id_base msg.plan.id_base = id_base
plan.set_message( plan.set_message(
......
...@@ -56,8 +56,8 @@ void ClientConnection::on_message(const char *buffer, size_t size) ...@@ -56,8 +56,8 @@ void ClientConnection::on_message(const char *buffer, size_t size)
case ClientRequest_Type_PLAN: { case ClientRequest_Type_PLAN: {
logger->debug("Plan received"); logger->debug("Plan received");
const Plan &plan = request.plan(); const Plan &plan = request.plan();
loom::base::Id id_base = task_manager.add_plan(plan); loom::base::Id id_base = task_manager.add_plan(plan, request.load_checkpoints());
logger->info("Plan submitted tasks={}", plan.tasks_size()); logger->info("Plan submitted tasks={}, load_checkpoints={}", plan.tasks_size(), request.load_checkpoints());
if (server.get_trace()) { if (server.get_trace()) {
server.create_file_in_trace_dir(std::to_string(id_base) + ".plan", buffer, size); server.create_file_in_trace_dir(std::to_string(id_base) + ".plan", buffer, size);
......
...@@ -23,26 +23,17 @@ ComputationState::ComputationState(Server &server) : server(server) ...@@ -23,26 +23,17 @@ ComputationState::ComputationState(Server &server) : server(server)
void ComputationState::add_node(std::unique_ptr<TaskNode> &&node) { void ComputationState::add_node(std::unique_ptr<TaskNode> &&node) {
auto id = node->get_id(); auto id = node->get_id();
/*
for (TaskNode* input_node : node->get_inputs()) {
input_node->add_next(node.get());
}
if (node->is_ready()) {
pending_nodes.insert(node.get());
}*/
auto result = nodes.insert(std::make_pair(id, std::move(node))); auto result = nodes.insert(std::make_pair(id, std::move(node)));
assert(result.second); // Check that ID is fresh assert(result.second); // Check that ID is fresh
} }
void ComputationState::plan_node(TaskNode &node, std::vector<TaskNode *> &to_load) { void ComputationState::plan_node(TaskNode &node, bool load_checkpoints, std::vector<TaskNode *> &to_load) {
if (node.is_planned()) { if (node.is_planned()) {
return; return;
} }
node.set_planned(); node.set_planned();
if (!node.get_task_def().checkpoint_path.empty() && loom::base::file_exists(node.get_task_def().checkpoint_path.c_str())) { if (load_checkpoints && !node.get_task_def().checkpoint_path.empty() && loom::base::file_exists(node.get_task_def().checkpoint_path.c_str())) {
node.set_checkpoint(); node.set_checkpoint();
to_load.push_back(&node); to_load.push_back(&node);
return; return;
...@@ -50,7 +41,7 @@ void ComputationState::plan_node(TaskNode &node, std::vector<TaskNode *> &to_loa ...@@ -50,7 +41,7 @@ void ComputationState::plan_node(TaskNode &node, std::vector<TaskNode *> &to_loa
int remaining_inputs = 0; int remaining_inputs = 0;
for (TaskNode *input_node : node.get_inputs()) { for (TaskNode *input_node : node.get_inputs()) {
plan_node(*input_node, to_load); plan_node(*input_node, load_checkpoints, to_load);
if (!input_node->is_computed()) { if (!input_node->is_computed()) {
remaining_inputs += 1; remaining_inputs += 1;
input_node->add_next(&node); input_node->add_next(&node);
...@@ -269,7 +260,7 @@ void ComputationState::make_expansion(std::vector<std::string> &configs, ...@@ -269,7 +260,7 @@ void ComputationState::make_expansion(std::vector<std::string> &configs,
} }
}*/ }*/
loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, std::vector<TaskNode*> &to_load) loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints, std::vector<TaskNode*> &to_load)
{ {
auto task_size = plan.tasks_size(); auto task_size = plan.tasks_size();
assert(plan.has_id_base()); assert(plan.has_id_base());
...@@ -317,7 +308,7 @@ loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, std: ...@@ -317,7 +308,7 @@ loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, std:
auto new_node = std::make_unique<TaskNode>(id, std::move(def)); auto new_node = std::make_unique<TaskNode>(id, std::move(def));
if (is_result) { if (is_result) {
plan_node(*new_node.get(), to_load); plan_node(*new_node.get(), load_checkpoints, to_load);
} }
add_node(std::move(new_node)); add_node(std::move(new_node));
} }
......
...@@ -43,7 +43,7 @@ public: ...@@ -43,7 +43,7 @@ public:
int get_n_data_objects() const; int get_n_data_objects() const;
loom::base::Id add_plan(const loom::pb::comm::Plan &plan, std::vector<TaskNode *> &to_load); loom::base::Id add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints, std::vector<TaskNode *> &to_load);
void test_ready_nodes(std::vector<loom::base::Id> ids); void test_ready_nodes(std::vector<loom::base::Id> ids);
loom::base::Id pop_result_client_id(loom::base::Id id); loom::base::Id pop_result_client_id(loom::base::Id id);
...@@ -61,7 +61,7 @@ public: ...@@ -61,7 +61,7 @@ public:
std::unique_ptr<TaskNode> pop_node(loom::base::Id id); std::unique_ptr<TaskNode> pop_node(loom::base::Id id);
void clear_all(); void clear_all();
void add_pending_node(TaskNode &node); void add_pending_node(TaskNode &node);
void plan_node(TaskNode &node, std::vector<TaskNode*> &to_load); void plan_node(TaskNode &node, bool load_checkpoints, std::vector<TaskNode*> &to_load);
private: private:
std::unordered_map<loom::base::Id, std::unique_ptr<TaskNode>> nodes; std::unordered_map<loom::base::Id, std::unique_ptr<TaskNode>> nodes;
......
...@@ -20,10 +20,10 @@ TaskManager::TaskManager(Server &server) ...@@ -20,10 +20,10 @@ TaskManager::TaskManager(Server &server)
{ {
} }
loom::base::Id TaskManager::add_plan(const loom::pb::comm::Plan &plan) loom::base::Id TaskManager::add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints)
{ {
std::vector<TaskNode*> to_load; std::vector<TaskNode*> to_load;
loom::base::Id id_base = cstate.add_plan(plan, to_load); loom::base::Id id_base = cstate.add_plan(plan, load_checkpoints, to_load);
for (TaskNode *node : to_load) { for (TaskNode *node : to_load) {
WorkerConnection *wc = random_worker(); WorkerConnection *wc = random_worker();
node->set_as_loading(wc); node->set_as_loading(wc);
......
...@@ -23,7 +23,7 @@ public: ...@@ -23,7 +23,7 @@ public:
cstate.add_node(std::move(node)); cstate.add_node(std::move(node));
}*/ }*/
loom::base::Id add_plan(const loom::pb::comm::Plan &plan); loom::base::Id add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints);
void on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing); void on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing);
void on_data_transferred(loom::base::Id id, WorkerConnection *wc); void on_data_transferred(loom::base::Id id, WorkerConnection *wc);
......
...@@ -140,12 +140,12 @@ class LoomEnv(Env): ...@@ -140,12 +140,12 @@ class LoomEnv(Env):
self.check_stats() self.check_stats()
return self._client return self._client
def submit_and_gather(self, tasks, check=True): def submit_and_gather(self, tasks, check=True, load=False):
if isinstance(tasks, Task): if isinstance(tasks, Task):
future = self.client.submit_one(tasks) future = self.client.submit_one(tasks, load=load)
return self.client.gather_one(future) return self.client.gather_one(future)
else: else:
futures = self.client.submit(tasks) futures = self.client.submit(tasks, load=load)
return self.client.gather(futures) return self.client.gather(futures)
if check: if check:
self.check_final_state() self.check_final_state()
......
...@@ -51,4 +51,4 @@ def test_checkpoint_load(loom_env): ...@@ -51,4 +51,4 @@ def test_checkpoint_load(loom_env):
x4 = tasks.merge((x3, x1, x2, t1, t2, t3)) x4 = tasks.merge((x3, x1, x2, t1, t2, t3))
x4.checkpoint_path = path5 x4.checkpoint_path = path5
assert loom_env.submit_and_gather(x4) == b'[4][1][2]$t3$[3][1][2]$t3$' assert loom_env.submit_and_gather(x4, load=True) == b'[4][1][2]$t3$[3][1][2]$t3$'
\ No newline at end of file \ No newline at end of file
...@@ -319,7 +319,7 @@ static std::vector<TaskNode*> nodes(ComputationState &s, std::vector<loom::base: ...@@ -319,7 +319,7 @@ static std::vector<TaskNode*> nodes(ComputationState &s, std::vector<loom::base:
static void add_plan(ComputationState &s, const loom::pb::comm::Plan &plan) { static void add_plan(ComputationState &s, const loom::pb::comm::Plan &plan) {
std::vector<TaskNode*> to_load; std::vector<TaskNode*> to_load;
s.add_plan(plan, to_load); s.add_plan(plan, false, to_load);
assert(to_load.empty()); assert(to_load.empty());
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment