Skip to content
Snippets Groups Projects
worker.cpp 14.21 KiB
#include "worker.h"
#include "loomcomm.pb.h"
#include "utils.h"
#include "log.h"
#include "types.h"

#include "data/rawdata.h"
#include "data/array.h"
#include "data/index.h"

#include "tasks/basetasks.h"
#include "tasks/rawdatatasks.h"
#include "tasks/arraytasks.h"
#include "tasks/runtask.h"

#include <stdlib.h>
#include <sstream>
#include <unistd.h>

using namespace loom;


/*
 * Creates a worker attached to the given libuv loop.
 *
 * - Opens the TCP listening socket on an OS-assigned port (start_listen()).
 * - If server_address is non-empty, starts asynchronous resolution of the
 *   server host name; the server connection itself is opened later in
 *   _on_getaddrinfo() using server_port.
 * - If work_dir_root is non-empty, creates the directory
 *   "<work_dir_root>/worker-<hostname>-<listen_port>/data" and terminates
 *   the process when that fails.
 * - Registers the built-in data unpackers and starts with one CPU slot.
 */
Worker::Worker(uv_loop_t *loop,
               const std::string &server_address,
               int server_port,
               const std::string &work_dir_root)
    : loop(loop),
      server_conn(*this),
      server_port(server_port)
{
    GOOGLE_PROTOBUF_VERIFY_VERSION;
    UV_CHECK(uv_tcp_init(loop, &listen_socket));
    listen_socket.data = this;
    // Must run before the work directory is named: the name embeds listen_port.
    start_listen();

    if (!server_address.empty()) {
        llog->info("Connecting to server {}:{}", server_address, server_port);
        uv_getaddrinfo_t* handle = new uv_getaddrinfo_t;  // deleted in _on_getaddrinfo
        handle->data = this;
        struct addrinfo hints;
        memset(&hints, 0, sizeof(hints));
        // _on_getaddrinfo() formats the result with uv_ip4_name(), so only IPv4
        // answers are usable; AF_UNSPEC could return an IPv6 address first.
        hints.ai_family = AF_INET;
        hints.ai_socktype = SOCK_STREAM;
        // The "80" service is only a placeholder: the callback keeps just the
        // IP address and connects to server_port.
        UV_CHECK(uv_getaddrinfo(
            loop, handle, _on_getaddrinfo, server_address.c_str(), "80", &hints));
    }

    if (!work_dir_root.empty()) {
        std::stringstream s;
        s << work_dir_root;
        if (work_dir_root.back() != '/') {
            s << "/";
        }
        char hostname[100];
        if (gethostname(hostname, sizeof(hostname))) {
            llog->error("Cannot get hostname, using 'nohostname'");
            strcpy(hostname, "nohostname");
        }
        // POSIX leaves NUL termination unspecified when the name is truncated.
        hostname[sizeof(hostname) - 1] = '\0';
        s << "worker-" << hostname << '-' << listen_port << '/';
        work_dir = s.str();

        if (make_path(work_dir.c_str(), S_IRWXU)) {
            llog->critical("Cannot create working directory '{}'", work_dir);
            exit(1);
        }

        if (mkdir((work_dir + "data").c_str(), S_IRWXU)) {
            llog->critical("Cannot create 'data' in working directory");
            exit(1);
        }

        llog->info("Using '{}' as working directory", work_dir);
    }

    add_unpacker<RawDataUnpacker>();
    add_unpacker<ArrayUnpacker>();
    add_unpacker<IndexUnpacker>();

    resource_cpus = 1;
}

/*
 * Queues the factories for the built-in task types.
 *
 * The factories are added through add_task_factory<T>(name). Until the server
 * dictionary arrives they sit in unregistered_task_factories, and
 * register_worker() advertises their names in that order. They are moved into
 * task_factories (keyed by dictionary symbol id) in on_dictionary_updated().
 * The order of the calls below is therefore visible to the server.
 */
void Worker::register_basic_tasks()
{
    // Base
    add_task_factory<GetTask>("base/get");
    add_task_factory<SliceTask>("base/slice");

    // RawData
    add_task_factory<ConstTask>("data/const");
    add_task_factory<MergeTask>("data/merge");
    add_task_factory<OpenTask>("data/open");
    add_task_factory<SplitTask>("data/split");

    // Arrays
    add_task_factory<ArrayMakeTask>("array/make");

    // Run
    add_task_factory<RunTask>("run/run");
}


/*
 * libuv callback for the server name resolution started in the constructor.
 *
 * On success, stores the first IPv4 address of the answer in
 * worker->server_address and connects the server connection to it on
 * worker->server_port. If resolution failed or produced no IPv4 address,
 * closes all worker handles and terminates the process.
 * Always frees both `response` and `handle`.
 */
void Worker::_on_getaddrinfo(uv_getaddrinfo_t* handle, int status,
        struct addrinfo* response) {
    Worker *worker = static_cast<Worker*>(handle->data);

    // Only IPv4 is supported (uv_ip4_name below). Search the whole list rather
    // than asserting on the first entry, which may be of another family.
    struct addrinfo *ipv4 = nullptr;
    if (status == 0) {
        for (struct addrinfo *ai = response; ai != nullptr; ai = ai->ai_next) {
            if (ai->ai_family == AF_INET) {
                ipv4 = ai;
                break;
            }
        }
    }

    if (ipv4 == nullptr) {
        if (status) {
            llog->critical("Cannot resolve server name");
        } else {
            llog->critical("Cannot resolve server name: no IPv4 address found");
        }
        uv_freeaddrinfo(response);  // libuv accepts NULL here
        delete handle;
        worker->close_all();
        exit(1);
    }

    char tmp[60];
    UV_CHECK(uv_ip4_name((struct sockaddr_in*) ipv4->ai_addr, tmp, sizeof(tmp)));
    worker->server_address = tmp;

    uv_freeaddrinfo(response);
    delete handle;

    llog->debug("Server address resolved to {}", worker->server_address);
    worker->server_conn.connect(worker->server_address, worker->server_port);
}

/*
 * libuv listen callback: accepts an incoming peer connection on the worker's
 * listening socket and transfers ownership of it to the worker through
 * add_connection().
 */
void Worker::_on_new_connection(uv_stream_t *stream, int status)
{
    UV_CHECK(status);
    Worker &self = *static_cast<Worker*>(stream->data);

    auto incoming = std::make_unique<InterConnection>(self);
    incoming->accept(&self.listen_socket);
    llog->debug("Worker connection from {}", incoming->get_peername());

    self.add_connection(std::move(incoming));
}

/*
 * Binds the listening socket to 0.0.0.0 on port 0 (so the OS picks a free
 * port), starts listening with a backlog of 10, and stores the port that was
 * actually assigned in listen_port.
 */
void Worker::start_listen()
{
    struct sockaddr_in addr;
    UV_CHECK(uv_ip4_addr("0.0.0.0", 0, &addr));

    UV_CHECK(uv_tcp_bind(&listen_socket, (const struct sockaddr *) &addr, 0));
    UV_CHECK(uv_listen((uv_stream_t *) &listen_socket, 10, _on_new_connection));

    // Read back the OS-chosen port. The result must be checked: on failure
    // `sockname` would stay uninitialized and listen_port would be garbage.
    struct sockaddr_in sockname;
    int namelen = sizeof(sockname);
    UV_CHECK(uv_tcp_getsockname(&listen_socket, (sockaddr*) &sockname, &namelen));
    listen_port = ntohs(sockname.sin_port);
}

/*int Worker::get_listen_port()
{
    struct sockaddr_in sa;
    int name_len = sizeof(sa);
    UV_CHECK(uv_tcp_getsockname(&listen_socket, (sockaddr *) &sa, &name_len));
    return ntohs(sa.sin_port);
}*/

void Worker::register_worker()
{
    loomcomm::Register msg;
    msg.set_type(loomcomm::Register_Type_REGISTER_WORKER);
    msg.set_protocol_version(PROTOCOL_VERSION);

    msg.set_port(get_listen_port());
    msg.set_cpus(resource_cpus);

    for (auto& factory : unregistered_task_factories) {
        msg.add_task_types(factory->get_name());
    }

    for (auto& factory : unregistered_unpack_factories) {
        msg.add_data_types(factory->get_type_name());
    }

    server_conn.send_message(msg);
}

/*
 * Accepts a task from the server.
 *
 * A task whose inputs are not all available yet is parked in waiting_tasks,
 * where check_waiting_tasks() re-examines it. Otherwise it is queued in
 * ready_tasks and the ready queue is processed immediately.
 */
void Worker::new_task(std::unique_ptr<Task> task)
{
    if (!task->is_ready(*this)) {
        waiting_tasks.push_back(std::move(task));
        return;
    }
    ready_tasks.push_back(std::move(task));
    check_ready_tasks();
}

/*
 * Instantiates and starts a task whose inputs are all available.
 *
 * Looks up the factory registered for the task type and terminates the
 * process if the type is unknown. Stores the created TaskInstance in
 * active_tasks, consumes one CPU slot, and starts the instance with pointers
 * to its input data objects.
 */
void Worker::start_task(std::unique_ptr<Task> task)
{
    llog->debug("Starting task id={} task_type={}", task->get_id(), task->get_task_type());
    auto i = task_factories.find(task->get_task_type());
    if (unlikely(i == task_factories.end())) {
        llog->critical("Task with unknown type {} received", task->get_task_type());
        exit(1);
    }
    auto task_instance = i->second->make_instance(*this, std::move(task));
    TaskInstance *t = task_instance.get();
    active_tasks.push_back(std::move(task_instance));

    DataVector input_data;
    for (Id id : t->get_inputs()) {
        input_data.push_back(&get_data(id));
    }

    // Reserve the CPU slot *before* starting. If start() completes the task
    // synchronously, task_finished()/task_failed() return the slot and may
    // start further ready tasks; they must see the already-reduced count,
    // otherwise more tasks than CPUs get started.
    // `t` is not used after start(): completion removes it from active_tasks.
    resource_cpus -= 1;
    t->start(input_data);
}

/*void Worker::new_task(const Task &task)
{
    auto task_type = task.get_task_type();
    assert(task_type >= 0 && task_type < (int) task_factories.size());
    auto task_instance = task_factories[task_type]->make_instance(*this, task);
    TaskInstance *t = task_instance.get();
    active_tasks.push_back(std::move(task_instance));
    t->start();
}*/

/*
 * Makes `data` available under `id` (replacing any previous object with that
 * id), then re-checks waiting tasks, since some of them may have been waiting
 * for exactly this object.
 */
void Worker::publish_data(Id id, std::shared_ptr<Data> &data)
{
    llog->debug("Publishing data id={} size={}", id, data->get_size());
    auto &slot = public_data[id];
    slot = data;
    check_waiting_tasks();
}

/*
 * Drops the published data object with the given id.
 *
 * An unknown id is a protocol inconsistency. Debug builds assert on it.
 * Release builds log an error and ignore the request; previously this path
 * called erase(end()), which is undefined behavior.
 */
void Worker::remove_data(Id id)
{
    llog->debug("Removing data id={}", id);
    if (unlikely(public_data.erase(id) == 0)) {
        llog->error("Cannot remove unknown data id={}", id);
        assert(0);
    }
}

/*
 * Returns the connection to `address` ("<host>:<port>"), creating and
 * connecting it on first use. The host "!server" stands for the resolved
 * server address. An address without a parsable port terminates the process,
 * since connecting to an undefined port cannot succeed.
 */
InterConnection& Worker::get_connection(const std::string &address)
{
    auto &connection = connections[address];
    if (connection.get() == nullptr) {
        llog->info("Connecting to {}", address);

        std::stringstream ss(address);
        std::string base_address;
        std::getline(ss, base_address, ':');
        int port = 0;  // previously left uninitialized when parsing failed
        if (!(ss >> port)) {
            llog->critical("Invalid address '{}' (expected host:port)", address);
            exit(1);
        }

        connection = std::make_unique<InterConnection>(*this);
        if (base_address == "!server") {
            connection->connect(server_address, port);
        } else {
            connection->connect(base_address, port);
        }
    }
    return *connection;
}

void Worker::close_all()
{
    uv_close((uv_handle_t*) &listen_socket, NULL);
    server_conn.close();
    for (auto& pair : connections) {
        pair.second->close();
    }
    for (auto& c : nonregistered_connections) {
        c->close();
    }
}

/*
 * Moves `connection` from nonregistered_connections into the connections map,
 * keyed by its address.
 *
 * If a connection for that address is already registered (two workers
 * connected to each other at the same moment), the new one stays in
 * nonregistered_connections as a harmless redundant connection.
 */
void Worker::register_connection(InterConnection &connection)
{
    auto &slot = connections[connection.get_address()];
    if (unlikely(slot.get() != nullptr)) {
        llog->debug("Registration collision, leaving unregisted");
        return;
    }

    auto owner = std::find_if(
        nonregistered_connections.begin(),
        nonregistered_connections.end(),
        [&connection](const std::unique_ptr<InterConnection> &candidate) {
            return candidate.get() == &connection;
        });
    assert(owner != nonregistered_connections.end());

    slot = std::move(*owner);
    nonregistered_connections.erase(owner);
}

/*
 * Forgets (and destroys) `connection`, whether it is registered in the
 * connections map or still in nonregistered_connections.
 *
 * The registered entry is erased only if it is this very object. After a
 * registration collision (see register_connection), a redundant connection
 * with the same address lives in nonregistered_connections. Looking up by
 * address alone would erase the *other*, healthy connection instead.
 */
void Worker::unregister_connection(InterConnection &connection)
{
    auto i = connections.find(connection.get_address());
    if (i != connections.end() && i->second.get() == &connection) {
        connections.erase(i);
        return;
    }

    auto j = std::find_if(
                nonregistered_connections.begin(),
                nonregistered_connections.end(),
                [&](std::unique_ptr<InterConnection>& p) { return p.get() == &connection; });
    assert(j != nonregistered_connections.end());
    if (j != nonregistered_connections.end()) {  // avoid erase(end()) under NDEBUG
        nonregistered_connections.erase(j);
    }
}

/*
 * Returns the directory path used for run `id`: "<work_dir>run/<id>/".
 * work_dir already ends with '/' when it is set by the constructor.
 */
std::string Worker::get_run_dir(Id id)
{
    std::ostringstream path;
    path << work_dir << "run/" << id << '/';
    return path.str();
}

/*
 * Starts ready tasks in FIFO order while CPU slots are available.
 * The conditions are re-evaluated on each iteration because start_task() can
 * change both resource_cpus and ready_tasks.
 */
void Worker::check_ready_tasks()
{
    for (;;) {
        if (resource_cpus <= 0 || ready_tasks.empty()) {
            break;
        }
        auto next = std::move(ready_tasks.front());
        ready_tasks.erase(ready_tasks.begin());
        start_task(std::move(next));
    }
}

/*
 * Sets the number of CPU slots the worker may use.
 * 0 means "autodetect" (online processors via sysconf). Any value that is
 * still non-positive afterwards is clamped to 1.
 */
void Worker::set_cpus(int value)
{
    int cpus = value;
    if (cpus == 0) {
        cpus = sysconf(_SC_NPROCESSORS_ONLN);
        llog->debug("Autodetection of CPUs: {}", cpus);
    }
    cpus = (cpus > 0) ? cpus : 1;

    resource_cpus = cpus;
    llog->info("Number of CPUs for worker: {}", cpus);
}

/*
 * Queues an unpack factory. It becomes usable by unpack() once
 * on_dictionary_updated() has assigned it a dictionary id.
 */
void Worker::add_unpacker(std::unique_ptr<UnpackFactory> factory)
{
    unregistered_unpack_factories.emplace_back(std::move(factory));
}

/*
 * Creates a new unpacker for data of type `id`, using the factory registered
 * in on_dictionary_updated().
 *
 * An unknown type id terminates the process, matching how start_task()
 * treats unknown task types. Previously this was only an assert, so release
 * builds dereferenced end().
 */
std::unique_ptr<DataUnpacker> Worker::unpack(DataTypeId id)
{
    auto i = unpack_factories.find(id);
    if (unlikely(i == unpack_factories.end())) {
        llog->critical("Cannot unpack data with unknown type id={}", id);
        exit(1);
    }
    return i->second->make_unpacker();
}

void Worker::on_dictionary_updated()
{
    for (auto &f : unregistered_task_factories) {
        loom::Id id = dictionary.lookup_symbol(f->get_name());
        llog->debug("Registering task_factory: {} = {}", f->get_name(), id);
        task_factories[id] = std::move(f);
    }
    unregistered_task_factories.clear();
    for (auto &f : unregistered_unpack_factories) {
        loom::Id id = dictionary.lookup_symbol(f->get_type_name());
        llog->debug("Registering unpack_factory: {} = {}", f->get_type_name(), id);
        unpack_factories[id] = std::move(f);
    }
    unregistered_unpack_factories.clear();
}

/*
 * Moves every waiting task whose inputs are now available to ready_tasks,
 * preserving their relative order. If at least one task was moved, the ready
 * queue is processed.
 */
void Worker::check_waiting_tasks()
{
    bool promoted = false;
    for (auto it = waiting_tasks.begin(); it != waiting_tasks.end();) {
        if (!(*it)->is_ready(*this)) {
            ++it;
            continue;
        }
        ready_tasks.push_back(std::move(*it));
        it = waiting_tasks.erase(it);
        promoted = true;
    }

    if (promoted) {
        check_ready_tasks();
    }
}

/*
 * Destroys the active task instance with the same id as `task`.
 * `task` must not be used by the caller afterwards; it is typically the
 * destroyed object itself. Asserts if no such instance exists.
 */
void Worker::remove_task(TaskInstance &task)
{
    auto found = std::find_if(
        active_tasks.begin(), active_tasks.end(),
        [&task](const auto &candidate) {
            return candidate->get_id() == task.get_id();
        });

    if (found == active_tasks.end()) {
        assert(0);
        return;
    }
    active_tasks.erase(found);
}

/*
 * Handles a task failure.
 *
 * Logs the error and reports FAILED to the server if connected. It then
 * returns the task's CPU slot, destroys the instance, and starts queued ready
 * tasks. The last step was missing before (task_finished() already did it),
 * so a failure could leave ready tasks stuck with a free CPU.
 * `task` (and `error_msg`, if owned by it) must not be used after this call.
 */
void Worker::task_failed(TaskInstance &task, const std::string &error_msg)
{
    llog->error("Task id={} failed: {}", task.get_id(), error_msg);
    if (server_conn.is_connected()) {
        loomcomm::WorkerResponse msg;
        msg.set_type(loomcomm::WorkerResponse_Type_FAILED);
        msg.set_id(task.get_id());
        msg.set_error_msg(error_msg);
        server_conn.send_message(msg);
    }
    resource_cpus += 1;
    remove_task(task);
    check_ready_tasks();
}

/*
 * Handles successful task completion.
 *
 * Reports FINISH (with the result's size and length) to the server if
 * connected. It then frees the CPU slot, destroys the instance, and starts
 * queued ready tasks. `task` must not be used after this call.
 */
void Worker::task_finished(TaskInstance &task, Data &data)
{
    if (server_conn.is_connected()) {
        loomcomm::WorkerResponse response;
        response.set_type(loomcomm::WorkerResponse_Type_FINISH);
        response.set_id(task.get_id());
        response.set_size(data.get_size());
        response.set_length(data.get_length());
        server_conn.send_message(response);
    }

    ++resource_cpus;
    remove_task(task);
    check_ready_tasks();
}

void Worker::send_data(const std::string &address, Id id, std::shared_ptr<Data> &data, bool with_size)
{
    auto &connection = get_connection(address);;
    connection.send(id, data, with_size);
}

/*
 * Connection to the server, driven by the worker's event loop and reporting
 * back to `worker`.
 */
ServerConnection::ServerConnection(Worker &worker)
    : SimpleConnectionCallback(worker.get_loop()), worker(worker)
{
}

/*
 * Presumably invoked by SimpleConnectionCallback once the connection to the
 * server is established (confirm in the base class). Starts reading server
 * messages, then sends this worker's Register message.
 */
void ServerConnection::on_connection()
{
    connection.start_read();
    // Worker::register_worker() builds and sends the REGISTER_WORKER message.
    worker.register_worker();
}

/*
 * Called when the server connection is closed. The worker cannot operate
 * without the server, so it closes all of its handles (listening socket and
 * every connection) via Worker::close_all(). The process is not exited here;
 * presumably the event loop ends once the handles are closed.
 */
void ServerConnection::on_close()
{
    llog->critical("Connection to server is closed. Terminating ...");
    worker.close_all();
}

/*
 * Logs a libuv error reported on the server connection and closes that
 * connection. Closing presumably leads to on_close(), which shuts the worker
 * down; confirm in SimpleConnectionCallback.
 */
void ServerConnection::on_error(int error_code)
{
    llog->critical("Server connection error: {}", uv_strerror(error_code));
    connection.close();
}

void ServerConnection::on_message(const char *data, size_t size)
{
    loomcomm::WorkerCommand msg;
    assert(msg.ParseFromArray(data, size));
    auto type = msg.type();

    switch (type) {
    case loomcomm::WorkerCommand_Type_TASK: {
        llog->debug("Task id={} received", msg.id());
        auto task = std::make_unique<Task>(msg.id(),
                                           msg.task_type(),
                                           msg.task_config());
        for (int i = 0; i < msg.task_inputs_size(); i++) {
            task->add_input(msg.task_inputs(i));
        }
        worker.new_task(std::move(task));
        break;
    }
    case loomcomm::WorkerCommand_Type_REMOVE: {
        worker.remove_data(msg.id());
        break;
    }
    case loomcomm::WorkerCommand_Type_SEND: {
        auto& address = msg.address();
        /* "!" means address to server, so we replace the sign to proper address */
        if (address.size() > 2 && address[0] == '!' && address[1] == ':') {
            msg.set_address(worker.get_server_address() + ":" + address.substr(2, std::string::npos));
        }
        llog->debug("Sending data {} to {}", msg.id(), msg.address());
        bool with_size = msg.has_with_size() && msg.with_size();
        assert(worker.send_data(msg.address(), msg.id(), with_size));
        break;
    }
    case loomcomm::WorkerCommand_Type_DICTIONARY: {
        auto count = msg.symbols_size();
        llog->debug("New dictionary ({} symbols)", count);
        Dictionary &dictionary = worker.get_dictionary();
        for (int i = 0; i < count; i++) {
            dictionary.find_or_create(msg.symbols(i));
        }
        worker.on_dictionary_updated();
    } break;
    default:
        llog->critical("Invalid message");
        exit(1);
    }
}