diff --git a/src/server/compstate.h b/src/server/compstate.h index 797042bfbdc2d355bd73b25e80344c18d318ebad..d8697b95fb0191fdca6632d6e68527b1b7b0624f 100644 --- a/src/server/compstate.h +++ b/src/server/compstate.h @@ -62,7 +62,7 @@ public: void clear_all(); void add_pending_node(TaskNode &node); void plan_node(TaskNode &node, bool load_checkpoints, std::vector<TaskNode*> &to_load); - + void fail_task_on_worker(WorkerConnection &conn); private: std::unordered_map<loom::base::Id, std::unique_ptr<TaskNode>> nodes; std::unordered_set<TaskNode*> pending_nodes; diff --git a/src/server/server.cpp b/src/server/server.cpp index 9bf63f69b84371445f4a7bd1f369df8d4971e616..83bba35760d8a01ff4c5c16cf698692915c1a0f1 100644 --- a/src/server/server.cpp +++ b/src/server/server.cpp @@ -48,6 +48,7 @@ void Server::add_worker_connection(std::unique_ptr<WorkerConnection> &&conn) void Server::remove_worker_connection(WorkerConnection &conn) { + task_manager.worker_fail(conn); auto i = std::find_if( connections.begin(), connections.end(), diff --git a/src/server/taskmanager.cpp b/src/server/taskmanager.cpp index aca80595fcbe96c55ca6fcc880a9bf87b8bc9e66..f1c85c0b615fff7607131f436aa8b52c1868e334 100644 --- a/src/server/taskmanager.cpp +++ b/src/server/taskmanager.cpp @@ -331,6 +331,15 @@ void TaskManager::release_node(TaskNode *node) } } +void TaskManager::worker_fail(WorkerConnection &conn) +{ + auto cc = server.get_client_connection(); + if (cc) { + cc->send_task_failed(-1, conn, "Worker lost"); + } + trash_all_tasks(); +} + WorkerConnection *TaskManager::random_worker() { auto &connections = server.get_connections(); diff --git a/src/server/taskmanager.h b/src/server/taskmanager.h index aabf619d3241a481e4b58eba0a62c4b821c3a273..6448aebb4fefe4034a02e181b02e33de7f6d588b 100644 --- a/src/server/taskmanager.h +++ b/src/server/taskmanager.h @@ -46,6 +46,9 @@ public: void trash_all_tasks(); void release_node(TaskNode *node); + void fail_task_on_worker(WorkerConnection &conn); + void worker_fail(WorkerConnection &conn); + WorkerConnection *random_worker(); private: diff --git a/tests/client/loomenv.py b/tests/client/loomenv.py index 4f7dde5c3ae7ac2e41f137d593706f38c2237a20..5234b2c0a5cdec030b6c79f5aa118093dff277d8 100644 --- a/tests/client/loomenv.py +++ b/tests/client/loomenv.py @@ -55,9 +55,16 @@ class Env(): def kill_all(self): for fn in self.cleanups: fn() - for n, p in self.processes: + for _, p in self.processes: p.kill() + def kill(self, name): + for n, p in self.processes: + if n == name: + p.kill() + return + raise Exception("Unknown processes") + class LoomEnv(Env): @@ -115,6 +122,12 @@ class LoomEnv(Env): assert stats["n_workers"] == self.workers_count assert stats["n_data_objects"] == 0 + def kill_worker(self, id): + assert self.workers_count > 0 + self.kill("worker{}".format(id)) + self.workers_count -= 1 + time.sleep(0.02) + def check_final_state(self): time.sleep(0.25) self.check_stats() diff --git a/tests/client/test_fail.py b/tests/client/test_fail.py index 20a0ea00345ca4ad52b09467b84205fc7070410b..30c72a247dba438ab16bc0217a9f38444602eec0 100644 --- a/tests/client/test_fail.py +++ b/tests/client/test_fail.py @@ -1,5 +1,6 @@ from loomenv import loom_env, LOOM_TESTPROG, LOOM_TEST_DATA_DIR # noqa import loom.client.tasks as tasks # noqa +import time from loom import client import pytest @@ -119,3 +120,30 @@ def test_fail_and_report(loom_env): a = tasks.const("ABC") with pytest.raises(client.TaskFailed): loom_env.submit_and_gather((sleep(), sleep(), sleep(), fail(a))) + + +def test_crash_clean_worker(loom_env): + loom_env.start(2) + loom_env.kill_worker(0) + + a = tasks.const("ABCDE") + b = tasks.const("123") + c = tasks.merge((a, b)) + assert b"ABCDE123" == loom_env.submit_and_gather(c) + + +def test_crash_running_worker(loom_env): + + @tasks.py_task() + def sleep(): + import time + time.sleep(1) + return b"" + + loom_env.start(2) + a = sleep() + b = sleep() + (fa, fb) = loom_env.client.submit((a, b)) + time.sleep(0.3) + loom_env.kill_worker(0) + loom_env.client.gather((fa, fb))