Commit 903e1fe5 authored by Stanislav Bohm's avatar Stanislav Bohm

ENH: Task redirecting

parent 8d4a02a0
......@@ -41,6 +41,7 @@ add_library(libloom
interconnect.cpp
task.cpp
task.h
taskdesc.h
sendbuffer.h
sendbuffer.cpp
loomcomm.pb.h
......
......@@ -26,7 +26,7 @@ std::shared_ptr<Data> Data::get_slice(size_t from, size_t to)
void Data::serialize(Worker &worker, SendBuffer &buffer, std::shared_ptr<Data> &data_ptr)
{
loomcomm::Data msg;
msg.set_type_id(worker.get_dictionary().lookup_symbol(get_type_name()));
msg.set_type_id(worker.get_dictionary().find_symbol_or_fail(get_type_name()));
msg.set_size(get_size());
auto length = get_length();
if (length) {
......
......@@ -74,6 +74,8 @@ protected:
std::shared_ptr<Data> data;
};
typedef std::vector<std::shared_ptr<Data>> DataVector;
}
#endif // LOOM_DATA_H
......@@ -11,7 +11,7 @@ Dictionary::Dictionary()
}
Id Dictionary::lookup_symbol(const std::string &symbol)
Id Dictionary::find_symbol_or_fail(const std::string &symbol) const
{
auto i = symbol_to_id.find(symbol);
if(i == symbol_to_id.end()) {
......@@ -22,6 +22,16 @@ Id Dictionary::lookup_symbol(const std::string &symbol)
return i->second;
}
Id Dictionary::find_symbol(const std::string &symbol) const
{
auto i = symbol_to_id.find(symbol);
if(i == symbol_to_id.end()) {
return -1;
}
assert(i->second != -1);
return i->second;
}
Id Dictionary::find_or_create(const std::string &symbol)
{
auto i = symbol_to_id.find(symbol);
......
......@@ -15,7 +15,8 @@ class Dictionary {
public:
Dictionary();
loom::Id lookup_symbol(const std::string &symbol);
loom::Id find_symbol_or_fail(const std::string &symbol) const;
loom::Id find_symbol(const std::string &symbol) const;
loom::Id find_or_create(const std::string &symbol);
const std::string& translate(loom::Id id);
......
......@@ -16,6 +16,9 @@ public:
Task(Id id, int task_type, const std::string &config)
: id(id), task_type(task_type), config(config) {}
Task(Id id, int task_type, std::string &&config)
: id(id), task_type(task_type), config(std::move(config)) {}
Id get_id() const {
return id;
}
......
#ifndef LIBLOOM_TASKREDIRECT_H
#define LIBLOOM_TASKREDIRECT_H
#include "data.h"
#include <string.h>
namespace loom {
struct TaskDescription
{
std::string task_type;
std::string config;
DataVector inputs;
};
}
#endif // LIBLOOM_TASKREDIRECT_H
......@@ -50,3 +50,8 @@ void TaskInstance::finish(const std::shared_ptr<Data> &output)
assert(output);
worker.task_finished(*this, *output);
}
void TaskInstance::redirect(std::unique_ptr<TaskDescription> tdesc)
{
worker.task_redirect(*this, std::move(tdesc));
}
......@@ -3,6 +3,7 @@
#include "data.h"
#include "task.h"
#include "taskdesc.h"
#include<uv.h>
......@@ -15,8 +16,6 @@ namespace loom {
class Worker;
class Data;
typedef std::vector<std::shared_ptr<Data>> DataVector;
/** Base class for task instance - an actual state of computation of a task */
class TaskInstance
{
......@@ -30,7 +29,7 @@ public:
}
virtual ~TaskInstance();
int get_id() const {
return task->get_id();
}
......@@ -47,6 +46,7 @@ protected:
void fail(const std::string &error_msg);
void fail_libuv(const std::string &error_msg, int error_code);
void finish(const std::shared_ptr<Data> &output);
void redirect(std::unique_ptr<TaskDescription> tdesc);
Worker &worker;
......
......@@ -2,6 +2,7 @@
#include "python.h"
#include "../data/rawdata.h"
#include "../log.h"
#include "../compat.h"
#include "python_wrapper.h"
#include <Python.h>
......@@ -31,7 +32,7 @@ void loom::PyCallTask::start(loom::DataVector &inputs)
ThreadTaskInstance::start(inputs);
}
static PyObject* vector_of_data_to_list(const DataVector &data)
static PyObject* data_vector_to_list(const DataVector &data)
{
PyObject *list = PyTuple_New(data.size());
assert(list);
......@@ -43,6 +44,54 @@ static PyObject* vector_of_data_to_list(const DataVector &data)
return list;
}
static DataVector list_to_data_vector(PyObject *obj)
{
assert(PySequence_Check(obj));
size_t size = PySequence_Size(obj);
assert(PyErr_Occurred() == nullptr);
DataVector result;
result.reserve(size);
for (size_t i = 0; i < size; i++) {
PyObject *o = PySequence_GetItem(obj, i);
assert(o);
assert(is_data_wrapper(o));
DataWrapper *data = (DataWrapper*) o;
result.push_back(data->data);
Py_DecRef(o);
}
return result;
}
static bool is_task(PyObject *obj)
{
return (PyObject_HasAttrString(obj, "task_type") &&
PyObject_HasAttrString(obj, "config") &&
PyObject_HasAttrString(obj, "inputs"));
}
static std::string get_attr_string(PyObject *obj, const char *name)
{
PyObject *value = PyObject_GetAttrString(obj, name);
assert(value);
Py_ssize_t size;
char *ptr;
if (PyUnicode_Check(value)) {
ptr = PyUnicode_AsUTF8AndSize(value, &size);
assert(ptr);
} else if (PyBytes_Check(value)) {
size = PyBytes_GET_SIZE(value);
ptr = PyBytes_AsString(value);
assert(ptr);
} else {
assert(0);
}
std::string result(ptr, size);
Py_DecRef(value);
return result;
}
std::shared_ptr<Data> PyCallTask::run()
{
// Obtain GIL
......@@ -72,7 +121,7 @@ std::shared_ptr<Data> PyCallTask::run()
task->get_config().size());
assert(config_data);
PyObject *py_inputs = vector_of_data_to_list(inputs);
PyObject *py_inputs = data_vector_to_list(inputs);
assert(py_inputs);
......@@ -111,6 +160,21 @@ std::shared_ptr<Data> PyCallTask::run()
Py_DECREF(result);
PyGILState_Release(gstate);
return output;
} else if (is_task(result)) {
// Result is task
auto task_desc = std::make_unique<TaskDescription>();
task_desc->task_type = get_attr_string(result, "task_type");
task_desc->config = get_attr_string(result, "config");
PyObject *value = PyObject_GetAttrString(result, "inputs");
assert(value);
task_desc->inputs = list_to_data_vector(value);
Py_DECREF(value);
set_redirect(std::move(task_desc));
Py_DECREF(result);
PyGILState_Release(gstate);
return nullptr;
} else {
set_error("Invalid result from python code");
......
......@@ -99,3 +99,8 @@ DataWrapper *data_wrapper_create(const std::shared_ptr<loom::Data> &data)
self->data = data;
return self;
}
bool is_data_wrapper(PyObject *obj)
{
return Py_TYPE(obj) == &data_wrapper_type;
}
......@@ -11,6 +11,7 @@ typedef struct {
} DataWrapper;
void data_wrapper_init();
bool is_data_wrapper(PyObject *obj);
DataWrapper *data_wrapper_create(const std::shared_ptr<loom::Data> &data);
#endif // LIBLOOM_TASKS_PYTHON_WRAPPER_H
......@@ -2,6 +2,8 @@
#include "worker.h"
#include <sstream>
using namespace loom;
void ThreadTaskInstance::start(DataVector &input_data)
......@@ -25,6 +27,11 @@ void ThreadTaskInstance::set_error(const std::string &error_message)
this->error_message = error_message;
}
void ThreadTaskInstance::set_redirect(std::unique_ptr<TaskDescription> tredirect)
{
task_redirect = std::move(tredirect);
}
void ThreadTaskInstance::_work_cb(uv_work_t *req)
{
ThreadTaskInstance *ttinstance = static_cast<ThreadTaskInstance*>(req->data);
......@@ -36,10 +43,12 @@ void ThreadTaskInstance::_after_work_cb(uv_work_t *req, int status)
UV_CHECK(status);
ThreadTaskInstance *ttinstance = static_cast<ThreadTaskInstance*>(req->data);
if (ttinstance->error_message.empty()) {
if (ttinstance->result) {
if (ttinstance->result && !ttinstance->task_redirect) {
ttinstance->finish(ttinstance->result);
} else if (!ttinstance->result && ttinstance->task_redirect) {
ttinstance->redirect(std::move(ttinstance->task_redirect));
} else {
ttinstance->fail("ThreadTaskInstace::run has returned nullptr");
ttinstance->fail("ThreadTaskInstace::run returned nullptr or incosistent returned values");
}
} else {
ttinstance->fail(ttinstance->error_message);
......
......@@ -35,17 +35,20 @@ public:
protected:
/** This method is called outside of main thread if run_in_thread has returned true
* IMPORTANT: It can read only member variable "inputs" and calls method "set_error"
* IMPORTANT: It can read only member variable "inputs" and calls methods
* "set_error" or "set_redirect"
* All other things are not thread-safe!
* In case of error, call set_error and return nullptr
*/
virtual std::shared_ptr<Data> run() = 0;
void set_error(const std::string &error_message);
void set_redirect(std::unique_ptr<TaskDescription> tredirect);
DataVector inputs;
uv_work_t work;
std::shared_ptr<Data> result;
std::string error_message;
std::unique_ptr<TaskDescription> task_redirect;
static void _work_cb(uv_work_t *req);
static void _after_work_cb(uv_work_t *req, int status);
......
......@@ -191,7 +191,8 @@ void Worker::new_task(std::unique_ptr<Task> task)
void Worker::start_task(std::unique_ptr<Task> task)
{
llog->debug("Starting task id={} task_type={}", task->get_id(), task->get_task_type());
llog->debug("Starting task id={} task_type={} n_inputs={}",
task->get_id(), task->get_task_type(), task->get_inputs().size());
auto i = task_factories.find(task->get_task_type());
if (unlikely(i == task_factories.end())) {
llog->critical("Task with unknown type {} received", task->get_task_type());
......@@ -348,14 +349,14 @@ std::unique_ptr<DataUnpacker> Worker::unpack(DataTypeId id)
void Worker::on_dictionary_updated()
{
for (auto &f : unregistered_task_factories) {
loom::Id id = dictionary.lookup_symbol(f->get_name());
loom::Id id = dictionary.find_symbol_or_fail(f->get_name());
llog->debug("Registering task_factory: {} = {}", f->get_name(), id);
task_factories[id] = std::move(f);
}
unregistered_task_factories.clear();
for (auto &f : unregistered_unpack_factories) {
loom::Id id = dictionary.lookup_symbol(f->get_type_name());
loom::Id id = dictionary.find_symbol_or_fail(f->get_type_name());
llog->debug("Registering unpack_factory: {} = {}", f->get_type_name(), id);
unpack_factories[id] = std::move(f);
}
......@@ -381,8 +382,11 @@ void Worker::check_waiting_tasks()
}
}
void Worker::remove_task(TaskInstance &task)
void Worker::remove_task(TaskInstance &task, bool free_resources)
{
if (free_resources) {
resource_cpus += 1;
}
for (auto i = active_tasks.begin(); i != active_tasks.end(); i++) {
if ((*i)->get_id() == task.get_id()) {
active_tasks.erase(i);
......@@ -402,10 +406,31 @@ void Worker::task_failed(TaskInstance &task, const std::string &error_msg)
msg.set_error_msg(error_msg);
server_conn.send_message(msg);
}
resource_cpus += 1;
remove_task(task);
}
void Worker::task_redirect(TaskInstance &task,
std::unique_ptr<TaskDescription> new_task_desc)
{
loom::Id id = task.get_id();
llog->debug("Redirecting task id={} task_type={} n_inputs={}",
id, new_task_desc->task_type, new_task_desc->inputs.size());
remove_task(task, false);
Id task_type_id = dictionary.find_symbol_or_fail(new_task_desc->task_type);
auto new_task = std::make_unique<Task>(id, task_type_id,
std::move(new_task_desc->config));
auto i = task_factories.find(task_type_id);
if (unlikely(i == task_factories.end())) {
llog->critical("Task with unknown type {} received", new_task->get_task_type());
assert(0);
}
auto task_instance = i->second->make_instance(*this, std::move(new_task));
TaskInstance *t = task_instance.get();
active_tasks.push_back(std::move(task_instance));
t->start(new_task_desc->inputs);
}
void Worker::task_finished(TaskInstance &task, Data &data)
{
if (server_conn.is_connected()) {
......@@ -416,7 +441,6 @@ void Worker::task_finished(TaskInstance &task, Data &data)
msg.set_length(data.get_length());
server_conn.send_message(msg);
}
resource_cpus += 1;
remove_task(task);
check_ready_tasks();
}
......
......@@ -67,6 +67,7 @@ public:
void task_finished(TaskInstance &task_instance, Data &data);
void task_failed(TaskInstance &task_instance, const std::string &error_msg);
void task_redirect(TaskInstance &task, std::unique_ptr<TaskDescription> new_task_desc);
void publish_data(Id id, const std::shared_ptr<Data> &data);
void remove_data(Id id);
......@@ -141,7 +142,7 @@ private:
void register_worker();
void start_listen();
void remove_task(TaskInstance &task);
void remove_task(TaskInstance &task, bool free_resources=true);
void start_task(std::unique_ptr<Task> task);
//int get_listen_port();
......
......@@ -27,6 +27,35 @@ def test_py_call(loom_env):
assert result2 == b"Test"
def test_py_redirect1(loom_env):
def f(a, b):
return tasks.merge((a, b))
loom_env.start(1)
c = tasks.const("ABC")
d = tasks.const("12345")
a = tasks.py_call(f, (c, d))
result = loom_env.submit(a)
assert result == b"ABC12345"
def test_py_redirect2(loom_env):
def f(a, b):
return tasks.run("/bin/ls $X", [(b, "$X")])
loom_env.start(1)
c = tasks.const("abcdef")
d = tasks.const("/")
a = tasks.py_call(f, (c, d))
result = loom_env.submit(a)
assert b"bin\n" in result
assert b"usr\n" in result
def test_py_fail_too_many_args(loom_env):
def g():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment