Commit e98ad40e authored by Stanislav Bohm's avatar Stanislav Bohm

ENH: Basic error propagation

parent 37ba5dcb
from .client import Client # noqa
from .client import Client, LoomException, TaskFailed # noqa
from .plan import Plan # noqa
......@@ -7,6 +7,20 @@ from plan import Task
LOOM_PROTOCOL_VERSION = 1
class LoomException(Exception):
    """Base class of the exceptions defined in this module."""
class TaskFailed(LoomException):
    """Exception describing a single failed task.

    Attributes:
        id: id of the failed task.
        worker: worker value taken from the error report.
        error_msg: error text taken from the error report.

    The exception message is "Task id=<id> failed: <error_msg>".
    """

    def __init__(self, id, worker, error_msg):
        super(TaskFailed, self).__init__(
            "Task id={} failed: {}".format(id, error_msg))
        self.id = id
        self.worker = worker
        self.error_msg = error_msg
class Client(object):
def __init__(self, address, port, info=False):
......@@ -51,15 +65,21 @@ class Client(object):
if cmsg.type == ClientMessage.DATA:
prologue = cmsg.data
data[prologue.id] = self._receive_data()
else:
assert cmsg.type == ClientMessage.INFO
elif cmsg.type == ClientMessage.INFO:
self.add_info(cmsg.info)
elif cmsg.type == ClientMessage.ERROR:
self.process_error(cmsg)
if single_result:
return data[results.id]
else:
return [data[task.id] for task in results]
def process_error(self, cmsg):
    """Raise TaskFailed built from the ``error`` field of ``cmsg``.

    ``cmsg`` is asserted to have its ``error`` field set. This method
    never returns normally.
    """
    assert cmsg.HasField("error")
    failure = cmsg.error
    raise TaskFailed(failure.id, failure.worker, failure.error_msg)
def add_info(self, info):
self.info.append((info.id, info.worker))
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -26,6 +26,18 @@ const std::string TaskInstance::get_task_dir()
return name;
}
// Reports this task as failed by forwarding error_msg to
// Worker::task_failed.
// NOTE(review): Worker::task_failed finishes by calling remove_task() on
// this instance; presumably the object must not be touched after fail()
// returns -- confirm against remove_task.
void TaskInstance::fail(const std::string &error_msg)
{
worker.task_failed(*this, error_msg);
}
// Fails the task with the message "<error_msg>: <text>", where <text> is
// what uv_strerror() returns for error_code.
void TaskInstance::fail_libuv(const std::string &error_msg, int error_code)
{
    const std::string description = uv_strerror(error_code);
    fail(error_msg + ": " + description);
}
void TaskInstance::finish(std::shared_ptr<Data> &output)
{
worker.publish_data(get_id(), output);
......
......@@ -43,6 +43,8 @@ public:
virtual void start(DataVector &input_data) = 0;
protected:
// Reports the task as failed with the given message.
void fail(const std::string &error_msg);
// Like fail(), but appends ": " and the uv_strerror() text for error_code
// to the message.
void fail_libuv(const std::string &error_msg, int error_code);
void finish(std::shared_ptr<Data> &output);
void finish_without_data();
......
......@@ -160,7 +160,10 @@ void Worker::start_task(std::unique_ptr<Task> task)
{
llog->debug("Starting task id={} task_type={}", task->get_id(), task->get_task_type());
auto i = task_factories.find(task->get_task_type());
assert(i != task_factories.end());
if (unlikely(i == task_factories.end())) {
llog->critical("Task with unknown type received");
exit(1);
}
auto task_instance = i->second->make_instance(*this, std::move(task));
TaskInstance *t = task_instance.get();
active_tasks.push_back(std::move(task_instance));
......@@ -349,10 +352,25 @@ void Worker::remove_task(TaskInstance &task)
assert(0);
}
// Handles a failed task: logs the error, sends a FAILED WorkerResponse
// (task id + error_msg) to the server when the server connection is up,
// increments resource_cpus and removes the task.
void Worker::task_failed(TaskInstance &task, const std::string &error_msg)
{
    const auto task_id = task.get_id();
    llog->error("Task id={} failed: {}", task_id, error_msg);

    if (server_conn.is_connected()) {
        loomcomm::WorkerResponse failure;
        failure.set_type(loomcomm::WorkerResponse_Type_FAILED);
        failure.set_id(task_id);
        failure.set_error_msg(error_msg);
        server_conn.send_message(failure);
    }

    resource_cpus += 1;
    // Kept as the final step; `task` is not used after this call.
    remove_task(task);
}
void Worker::task_finished(TaskInstance &task)
{
if (server_conn.is_connected()) {
loomcomm::WorkerResponse msg;
msg.set_type(loomcomm::WorkerResponse_Type_FINISH);
msg.set_id(task.get_id());
server_conn.send_message(msg);
}
......
......@@ -63,6 +63,7 @@ public:
}
void task_finished(TaskInstance &task_instance);
void task_failed(TaskInstance &task_instance, const std::string &error_msg);
void publish_data(Id id, std::shared_ptr<Data> &data);
void remove_data(Id id);
......
......@@ -51,7 +51,15 @@ message WorkerCommand {
}
// Response sent by a worker about one task.
message WorkerResponse {
    enum Type {
        FINISH = 1;  // the task finished
        FAILED = 2;  // the task failed; error_msg carries the reason
    }
    required Type type = 1;
    // Id of the task this response refers to. Declared exactly once: a
    // second "id = 2" declaration would clash on both field name and number.
    required int32 id = 2;

    // Set when type == FAILED
    optional string error_msg = 3;
}
message Announce {
......@@ -75,12 +83,20 @@ message Info {
required string worker = 2;
}
// Description of a failed task; carried in ClientMessage.error.
message Error {
// Id of the failed task
required int32 id = 1;
// Worker the failure is attributed to (the server fills in the worker
// connection's address)
required string worker = 2;
// Error text
required string error_msg = 3;
}
// Message delivered to the client. `type` tells which of the optional
// payload fields below is the relevant one.
message ClientMessage {
enum Type {
// payload in `data`
DATA = 1;
// payload in `info`
INFO = 2;
// payload in `error`
ERROR = 3;
}
required Type type = 1;
optional DataPrologue data = 2;
optional Info info = 3;
optional Error error = 4;
}
......@@ -84,6 +84,32 @@ void Server::on_task_finished(TaskNode &task)
task_manager.on_task_finished(task);
}
// NOTE(review): empty stub -- error_msg is currently ignored and nothing
// is reported to anyone.
void Server::inform_about_error(std::string &error_msg)
{
}
// Logs the failure of task `id` on the worker behind `wconn`, forwards it
// to the client (when a client connection exists) as a ClientMessage of
// type ERROR, and then terminates the process with exit code 1.
void Server::inform_about_task_error(Id id, WorkerConnection &wconn, const std::string &error_msg)
{
    const auto &worker_address = wconn.get_address();
    llog->error("Task id={} failed on worker {}: {}",
                id, worker_address, error_msg);

    loomcomm::ClientMessage report;
    report.set_type(loomcomm::ClientMessage_Type_ERROR);
    loomcomm::Error *details = report.mutable_error();
    details->set_id(id);
    details->set_worker(worker_address);
    details->set_error_msg(error_msg);

    if (client_connection) {
        SendBuffer *outgoing = new SendBuffer();
        outgoing->add(report);
        client_connection->send_buffer(outgoing);
    }

    // NOTE(review): exit() runs right after the buffer is handed over; if
    // send_buffer() only queues the write, the client may never receive
    // this error -- confirm.
    exit(1);
}
void Server::start_listen()
{
struct sockaddr_in addr;
......
......@@ -56,6 +56,9 @@ public:
return dictionary;
}
void inform_about_error(std::string &error_msg);
void inform_about_task_error(loom::Id id, WorkerConnection &wconn, const std::string &error_msg);
private:
void start_listen();
......
......@@ -72,7 +72,6 @@ void TaskManager::add_plan(const loomplan::Plan &plan, bool distribute)
}
for (auto &t : tasks) {
llog->alert("{} {}", t.second->get_id(), t.second->get_ref_counter());
assert(t.second->get_ref_counter() > 0);
}
......
......@@ -51,13 +51,21 @@ void WorkerConnection::on_message(const char *buffer, size_t size)
msg.ParseFromArray(buffer, size);
const auto it = tasks.find(msg.id());
assert(it != tasks.end());
assert(it != tasks.end());
TaskNode *task = it->second;
tasks.erase(it);
task->add_owner(this);
task->set_finished();
server.on_task_finished(*task);
if (msg.type() == loomcomm::WorkerResponse_Type_FINISH) {
task->add_owner(this);
task->set_finished();
server.on_task_finished(*task);
return;
}
if (msg.type() == loomcomm::WorkerResponse_Type_FAILED) {
assert(msg.has_error_msg());
server.inform_about_task_error(msg.id(), *this, msg.error_msg());
}
}
void WorkerConnection::on_close()
......
......@@ -66,7 +66,6 @@ void RunTask::start(DataVector &inputs)
stdio[1].flags = UV_IGNORE;
options.stdio_count = 2;
llog->alert("{} {}", msg.map_inputs_size(), inputs.size());
assert(msg.map_inputs_size() <= static_cast<int>(inputs.size()));
for (int i = 0; i < msg.map_inputs_size(); i++) {
......@@ -105,7 +104,8 @@ void RunTask::start(DataVector &inputs)
}
}
UV_CHECK(uv_spawn(worker.get_loop(), &process, &options));
int r;
r = uv_spawn(worker.get_loop(), &process, &options);
process.data = this;
/* Cleanup */
......@@ -115,6 +115,11 @@ void RunTask::start(DataVector &inputs)
if (stdio[1].flags == UV_INHERIT_FD) {
close(stdio[1].data.fd);
}
if (r) {
fail_libuv("uv_spawn", r);
return;
}
}
std::string RunTask::get_path(const std::string &filename)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment