Commit 9be6b0ca authored by Stanislav Bohm

ENH: Dictionary of symbols added

parent d18fe288
......@@ -18,7 +18,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='loomcomm.proto',
package='loomcomm',
serialized_pb=_b('\n\x0eloomcomm.proto\x12\x08loomcomm\"\xbb\x01\n\x08Register\x12\x18\n\x10protocol_version\x18\x01 \x02(\x05\x12%\n\x04type\x18\x02 \x02(\x0e\x32\x17.loomcomm.Register.Type\x12\x0c\n\x04port\x18\x03 \x01(\x05\x12\x12\n\ntask_types\x18\x04 \x03(\t\x12\x0c\n\x04\x63pus\x18\x05 \x01(\x05\x12\x0c\n\x04info\x18\n \x01(\x08\"0\n\x04Type\x12\x13\n\x0fREGISTER_WORKER\x10\x01\x12\x13\n\x0fREGISTER_CLIENT\x10\x02\"&\n\rServerMessage\"\x15\n\x04Type\x12\r\n\tSTART_JOB\x10\x01\"\xd0\x01\n\rWorkerCommand\x12*\n\x04type\x18\x01 \x02(\x0e\x32\x1c.loomcomm.WorkerCommand.Type\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x11\n\ttask_type\x18\x03 \x01(\x05\x12\x13\n\x0btask_config\x18\x04 \x01(\t\x12\x13\n\x0btask_inputs\x18\x05 \x03(\x05\x12\x0f\n\x07\x61\x64\x64ress\x18\n \x01(\t\x12\x11\n\twith_size\x18\x0b \x01(\x08\"&\n\x04Type\x12\x08\n\x04TASK\x10\x01\x12\x08\n\x04SEND\x10\x02\x12\n\n\x06REMOVE\x10\x03\"\x1c\n\x0eWorkerResponse\x12\n\n\x02id\x18\x02 \x01(\x05\"\x18\n\x08\x41nnounce\x12\x0c\n\x04port\x18\x01 \x02(\x05\"-\n\x0c\x44\x61taPrologue\x12\n\n\x02id\x18\x01 \x02(\x05\x12\x11\n\tdata_size\x18\x03 \x01(\x04\"%\n\x04\x44\x61ta\x12\x0f\n\x07type_id\x18\x01 \x02(\x05\x12\x0c\n\x04size\x18\x02 \x01(\x04\"\"\n\x04Info\x12\n\n\x02id\x18\x01 \x02(\x05\x12\x0e\n\x06worker\x18\x02 \x02(\t\"\x9b\x01\n\rClientMessage\x12*\n\x04type\x18\x01 \x02(\x0e\x32\x1c.loomcomm.ClientMessage.Type\x12$\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x16.loomcomm.DataPrologue\x12\x1c\n\x04info\x18\x03 \x01(\x0b\x32\x0e.loomcomm.Info\"\x1a\n\x04Type\x12\x08\n\x04\x44\x41TA\x10\x01\x12\x08\n\x04INFO\x10\x02\x42\x02H\x03')
serialized_pb=_b('\n\x0eloomcomm.proto\x12\x08loomcomm\"\xbb\x01\n\x08Register\x12\x18\n\x10protocol_version\x18\x01 \x02(\x05\x12%\n\x04type\x18\x02 \x02(\x0e\x32\x17.loomcomm.Register.Type\x12\x0c\n\x04port\x18\x03 \x01(\x05\x12\x12\n\ntask_types\x18\x04 \x03(\t\x12\x0c\n\x04\x63pus\x18\x05 \x01(\x05\x12\x0c\n\x04info\x18\n \x01(\x08\"0\n\x04Type\x12\x13\n\x0fREGISTER_WORKER\x10\x01\x12\x13\n\x0fREGISTER_CLIENT\x10\x02\"&\n\rServerMessage\"\x15\n\x04Type\x12\r\n\tSTART_JOB\x10\x01\"\xf1\x01\n\rWorkerCommand\x12*\n\x04type\x18\x01 \x02(\x0e\x32\x1c.loomcomm.WorkerCommand.Type\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x11\n\ttask_type\x18\x03 \x01(\x05\x12\x13\n\x0btask_config\x18\x04 \x01(\t\x12\x13\n\x0btask_inputs\x18\x05 \x03(\x05\x12\x0f\n\x07\x61\x64\x64ress\x18\n \x01(\t\x12\x11\n\twith_size\x18\x0b \x01(\x08\x12\x0f\n\x07symbols\x18\x64 \x03(\t\"6\n\x04Type\x12\x08\n\x04TASK\x10\x01\x12\x08\n\x04SEND\x10\x02\x12\n\n\x06REMOVE\x10\x03\x12\x0e\n\nDICTIONARY\x10\x04\"\x1c\n\x0eWorkerResponse\x12\n\n\x02id\x18\x02 \x01(\x05\"\x18\n\x08\x41nnounce\x12\x0c\n\x04port\x18\x01 \x02(\x05\"-\n\x0c\x44\x61taPrologue\x12\n\n\x02id\x18\x01 \x02(\x05\x12\x11\n\tdata_size\x18\x03 \x01(\x04\"%\n\x04\x44\x61ta\x12\x0f\n\x07type_id\x18\x01 \x02(\x05\x12\x0c\n\x04size\x18\x02 \x01(\x04\"\"\n\x04Info\x12\n\n\x02id\x18\x01 \x02(\x05\x12\x0e\n\x06worker\x18\x02 \x02(\t\"\x9b\x01\n\rClientMessage\x12*\n\x04type\x18\x01 \x02(\x0e\x32\x1c.loomcomm.ClientMessage.Type\x12$\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x16.loomcomm.DataPrologue\x12\x1c\n\x04info\x18\x03 \x01(\x0b\x32\x0e.loomcomm.Info\"\x1a\n\x04Type\x12\x08\n\x04\x44\x41TA\x10\x01\x12\x08\n\x04INFO\x10\x02\x42\x02H\x03')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -82,11 +82,15 @@ _WORKERCOMMAND_TYPE = _descriptor.EnumDescriptor(
name='REMOVE', index=2, number=3,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DICTIONARY', index=3, number=4,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=429,
serialized_end=467,
serialized_start=446,
serialized_end=500,
)
_sym_db.RegisterEnumDescriptor(_WORKERCOMMAND_TYPE)
......@@ -107,8 +111,8 @@ _CLIENTMESSAGE_TYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=777,
serialized_end=803,
serialized_start=810,
serialized_end=836,
)
_sym_db.RegisterEnumDescriptor(_CLIENTMESSAGE_TYPE)
......@@ -259,6 +263,13 @@ _WORKERCOMMAND = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='symbols', full_name='loomcomm.WorkerCommand.symbols', index=7,
number=100, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
......@@ -272,7 +283,7 @@ _WORKERCOMMAND = _descriptor.Descriptor(
oneofs=[
],
serialized_start=259,
serialized_end=467,
serialized_end=500,
)
......@@ -301,8 +312,8 @@ _WORKERRESPONSE = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=469,
serialized_end=497,
serialized_start=502,
serialized_end=530,
)
......@@ -331,8 +342,8 @@ _ANNOUNCE = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=499,
serialized_end=523,
serialized_start=532,
serialized_end=556,
)
......@@ -368,8 +379,8 @@ _DATAPROLOGUE = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=525,
serialized_end=570,
serialized_start=558,
serialized_end=603,
)
......@@ -405,8 +416,8 @@ _DATA = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=572,
serialized_end=609,
serialized_start=605,
serialized_end=642,
)
......@@ -442,8 +453,8 @@ _INFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=611,
serialized_end=645,
serialized_start=644,
serialized_end=678,
)
......@@ -487,8 +498,8 @@ _CLIENTMESSAGE = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=648,
serialized_end=803,
serialized_start=681,
serialized_end=836,
)
_REGISTER.fields_by_name['type'].enum_type = _REGISTER_TYPE
......
......@@ -11,6 +11,8 @@ add_library(libloom
taskinstance.cpp
taskinstance.h
taskfactory.h
dictionary.cpp
dictionary.h
databuilder.cpp
databuilder.h
data.cpp
......
#include "dictionary.h"
#include <assert.h>
using namespace loom;
// Constructs an empty dictionary: symbol_to_id is default-constructed and
// therefore holds no symbols yet.
Dictionary::Dictionary()
{
}
/**
 * Returns the id already assigned to `symbol`.
 *
 * The symbol is expected to be present in the dictionary (for example after a
 * prior find_or_create() call). Debug builds abort through assert() when it is
 * missing. Builds compiled with NDEBUG throw std::out_of_range (raised by
 * unordered_map::at) instead of dereferencing the end iterator, which would
 * be undefined behaviour.
 *
 * @param symbol  Name to look up.
 * @return        The id stored for `symbol`.
 */
Id Dictionary::lookup_symbol(const std::string &symbol)
{
    // Keep the debug-time diagnostics of the original implementation.
    assert(symbol_to_id.find(symbol) != symbol_to_id.end());

    // at() performs a checked access, so a missing symbol is a well-defined
    // error even when the assert above is compiled out.
    const Id id = symbol_to_id.at(symbol);
    assert(id != -1);
    return id;
}
/**
 * Returns the id of `symbol`, registering the symbol first when it is unknown.
 *
 * A newly registered symbol receives the current number of stored symbols as
 * its id.
 *
 * @param symbol  Name to look up or register.
 * @return        The existing or freshly assigned id.
 */
Id Dictionary::find_or_create(const std::string &symbol)
{
    const auto found = symbol_to_id.find(symbol);
    if (found != symbol_to_id.end()) {
        return found->second;
    }

    // Unknown symbol: the next free id equals the size before insertion.
    const int fresh_id = symbol_to_id.size();
    symbol_to_id.emplace(symbol, fresh_id);
    return fresh_id;
}
/**
 * Builds a vector of all symbols in which each symbol sits at the index equal
 * to its id.
 *
 * Every stored id is asserted to lie in [0, number of symbols).
 *
 * @return  Vector with one slot per stored symbol, indexed by id.
 */
std::vector<std::string> Dictionary::get_all_symbols() const
{
    const int count = symbol_to_id.size();
    std::vector<std::string> by_id(count);

    for (const auto &entry : symbol_to_id) {
        const auto id = entry.second;
        assert(id >= 0 && id < count);
        by_id[id] = entry.first;
    }
    return by_id;
}
#ifndef LIBLOOM_DICTIONARY_H
#define LIBLOOM_DICTIONARY_H
#include "types.h"
#include <unordered_map>
#include <string>
#include <vector>
namespace loom {
// Maps symbol names (strings) to loom::Id values and back.
class Dictionary {
public:
// Creates a dictionary with no symbols.
Dictionary();
// Returns the id stored for `symbol`; the symbol must already be present
// (the implementation checks this with assert()).
loom::Id lookup_symbol(const std::string &symbol);
// Returns the id stored for `symbol`; when the symbol is unknown it is added
// with an id equal to the number of symbols stored so far.
loom::Id find_or_create(const std::string &symbol);
// Returns all symbols in a vector where each symbol is placed at the index
// equal to its id.
std::vector<std::string> get_all_symbols() const;
private:
// Symbol name -> assigned id.
std::unordered_map<std::string, loom::Id> symbol_to_id;
};
}
#endif // LIBLOOM_DICTIONARY_H
......@@ -640,6 +640,7 @@ bool WorkerCommand_Type_IsValid(int value) {
case 1:
case 2:
case 3:
case 4:
return true;
default:
return false;
......@@ -650,6 +651,7 @@ bool WorkerCommand_Type_IsValid(int value) {
const WorkerCommand_Type WorkerCommand::TASK;
const WorkerCommand_Type WorkerCommand::SEND;
const WorkerCommand_Type WorkerCommand::REMOVE;
const WorkerCommand_Type WorkerCommand::DICTIONARY;
const WorkerCommand_Type WorkerCommand::Type_MIN;
const WorkerCommand_Type WorkerCommand::Type_MAX;
const int WorkerCommand::Type_ARRAYSIZE;
......@@ -662,6 +664,7 @@ const int WorkerCommand::kTaskConfigFieldNumber;
const int WorkerCommand::kTaskInputsFieldNumber;
const int WorkerCommand::kAddressFieldNumber;
const int WorkerCommand::kWithSizeFieldNumber;
const int WorkerCommand::kSymbolsFieldNumber;
#endif // !_MSC_VER
WorkerCommand::WorkerCommand()
......@@ -763,6 +766,7 @@ void WorkerCommand::Clear() {
#undef ZR_
task_inputs_.Clear();
symbols_.Clear();
::memset(_has_bits_, 0, sizeof(_has_bits_));
mutable_unknown_fields()->clear();
}
......@@ -777,7 +781,7 @@ bool WorkerCommand::MergePartialFromCodedStream(
&unknown_fields_string);
// @@protoc_insertion_point(parse_start:loomcomm.WorkerCommand)
for (;;) {
::std::pair< ::google::protobuf::uint32, bool> p = input->ReadTagWithCutoff(127);
::std::pair< ::google::protobuf::uint32, bool> p = input->ReadTagWithCutoff(16383);
tag = p.first;
if (!p.second) goto handle_unusual;
switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) {
......@@ -887,6 +891,20 @@ bool WorkerCommand::MergePartialFromCodedStream(
} else {
goto handle_unusual;
}
if (input->ExpectTag(802)) goto parse_symbols;
break;
}
// repeated string symbols = 100;
case 100: {
if (tag == 802) {
parse_symbols:
DO_(::google::protobuf::internal::WireFormatLite::ReadString(
input, this->add_symbols()));
} else {
goto handle_unusual;
}
if (input->ExpectTag(802)) goto parse_symbols;
if (input->ExpectAtEnd()) goto success;
break;
}
......@@ -955,6 +973,12 @@ void WorkerCommand::SerializeWithCachedSizes(
::google::protobuf::internal::WireFormatLite::WriteBool(11, this->with_size(), output);
}
// repeated string symbols = 100;
for (int i = 0; i < this->symbols_size(); i++) {
::google::protobuf::internal::WireFormatLite::WriteString(
100, this->symbols(i), output);
}
output->WriteRaw(unknown_fields().data(),
unknown_fields().size());
// @@protoc_insertion_point(serialize_end:loomcomm.WorkerCommand)
......@@ -1014,6 +1038,13 @@ int WorkerCommand::ByteSize() const {
total_size += 1 * this->task_inputs_size() + data_size;
}
// repeated string symbols = 100;
total_size += 2 * this->symbols_size();
for (int i = 0; i < this->symbols_size(); i++) {
total_size += ::google::protobuf::internal::WireFormatLite::StringSize(
this->symbols(i));
}
total_size += unknown_fields().size();
GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN();
......@@ -1030,6 +1061,7 @@ void WorkerCommand::CheckTypeAndMergeFrom(
void WorkerCommand::MergeFrom(const WorkerCommand& from) {
GOOGLE_CHECK_NE(&from, this);
task_inputs_.MergeFrom(from.task_inputs_);
symbols_.MergeFrom(from.symbols_);
if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) {
if (from.has_type()) {
set_type(from.type());
......@@ -1074,6 +1106,7 @@ void WorkerCommand::Swap(WorkerCommand* other) {
task_inputs_.Swap(&other->task_inputs_);
std::swap(address_, other->address_);
std::swap(with_size_, other->with_size_);
symbols_.Swap(&other->symbols_);
std::swap(_has_bits_[0], other->_has_bits_[0]);
_unknown_fields_.swap(other->_unknown_fields_);
std::swap(_cached_size_, other->_cached_size_);
......
......@@ -62,11 +62,12 @@ const int ServerMessage_Type_Type_ARRAYSIZE = ServerMessage_Type_Type_MAX + 1;
enum WorkerCommand_Type {
WorkerCommand_Type_TASK = 1,
WorkerCommand_Type_SEND = 2,
WorkerCommand_Type_REMOVE = 3
WorkerCommand_Type_REMOVE = 3,
WorkerCommand_Type_DICTIONARY = 4
};
bool WorkerCommand_Type_IsValid(int value);
const WorkerCommand_Type WorkerCommand_Type_Type_MIN = WorkerCommand_Type_TASK;
const WorkerCommand_Type WorkerCommand_Type_Type_MAX = WorkerCommand_Type_REMOVE;
const WorkerCommand_Type WorkerCommand_Type_Type_MAX = WorkerCommand_Type_DICTIONARY;
const int WorkerCommand_Type_Type_ARRAYSIZE = WorkerCommand_Type_Type_MAX + 1;
enum ClientMessage_Type {
......@@ -397,6 +398,7 @@ class WorkerCommand : public ::google::protobuf::MessageLite {
static const Type TASK = WorkerCommand_Type_TASK;
static const Type SEND = WorkerCommand_Type_SEND;
static const Type REMOVE = WorkerCommand_Type_REMOVE;
static const Type DICTIONARY = WorkerCommand_Type_DICTIONARY;
static inline bool Type_IsValid(int value) {
return WorkerCommand_Type_IsValid(value);
}
......@@ -473,6 +475,22 @@ class WorkerCommand : public ::google::protobuf::MessageLite {
inline bool with_size() const;
inline void set_with_size(bool value);
// repeated string symbols = 100;
inline int symbols_size() const;
inline void clear_symbols();
static const int kSymbolsFieldNumber = 100;
inline const ::std::string& symbols(int index) const;
inline ::std::string* mutable_symbols(int index);
inline void set_symbols(int index, const ::std::string& value);
inline void set_symbols(int index, const char* value);
inline void set_symbols(int index, const char* value, size_t size);
inline ::std::string* add_symbols();
inline void add_symbols(const ::std::string& value);
inline void add_symbols(const char* value);
inline void add_symbols(const char* value, size_t size);
inline const ::google::protobuf::RepeatedPtrField< ::std::string>& symbols() const;
inline ::google::protobuf::RepeatedPtrField< ::std::string>* mutable_symbols();
// @@protoc_insertion_point(class_scope:loomcomm.WorkerCommand)
private:
inline void set_has_type();
......@@ -499,6 +517,7 @@ class WorkerCommand : public ::google::protobuf::MessageLite {
::google::protobuf::int32 task_type_;
bool with_size_;
::std::string* address_;
::google::protobuf::RepeatedPtrField< ::std::string> symbols_;
#ifdef GOOGLE_PROTOBUF_NO_STATIC_INITIALIZER
friend void protobuf_AddDesc_loomcomm_2eproto_impl();
#else
......@@ -1597,6 +1616,60 @@ inline void WorkerCommand::set_with_size(bool value) {
// @@protoc_insertion_point(field_set:loomcomm.WorkerCommand.with_size)
}
// repeated string symbols = 100;
inline int WorkerCommand::symbols_size() const {
return symbols_.size();
}
inline void WorkerCommand::clear_symbols() {
symbols_.Clear();
}
inline const ::std::string& WorkerCommand::symbols(int index) const {
// @@protoc_insertion_point(field_get:loomcomm.WorkerCommand.symbols)
return symbols_.Get(index);
}
inline ::std::string* WorkerCommand::mutable_symbols(int index) {
// @@protoc_insertion_point(field_mutable:loomcomm.WorkerCommand.symbols)
return symbols_.Mutable(index);
}
inline void WorkerCommand::set_symbols(int index, const ::std::string& value) {
// @@protoc_insertion_point(field_set:loomcomm.WorkerCommand.symbols)
symbols_.Mutable(index)->assign(value);
}
inline void WorkerCommand::set_symbols(int index, const char* value) {
symbols_.Mutable(index)->assign(value);
// @@protoc_insertion_point(field_set_char:loomcomm.WorkerCommand.symbols)
}
inline void WorkerCommand::set_symbols(int index, const char* value, size_t size) {
symbols_.Mutable(index)->assign(
reinterpret_cast<const char*>(value), size);
// @@protoc_insertion_point(field_set_pointer:loomcomm.WorkerCommand.symbols)
}
inline ::std::string* WorkerCommand::add_symbols() {
return symbols_.Add();
}
inline void WorkerCommand::add_symbols(const ::std::string& value) {
symbols_.Add()->assign(value);
// @@protoc_insertion_point(field_add:loomcomm.WorkerCommand.symbols)
}
inline void WorkerCommand::add_symbols(const char* value) {
symbols_.Add()->assign(value);
// @@protoc_insertion_point(field_add_char:loomcomm.WorkerCommand.symbols)
}
inline void WorkerCommand::add_symbols(const char* value, size_t size) {
symbols_.Add()->assign(reinterpret_cast<const char*>(value), size);
// @@protoc_insertion_point(field_add_pointer:loomcomm.WorkerCommand.symbols)
}
inline const ::google::protobuf::RepeatedPtrField< ::std::string>&
WorkerCommand::symbols() const {
// @@protoc_insertion_point(field_list:loomcomm.WorkerCommand.symbols)
return symbols_;
}
inline ::google::protobuf::RepeatedPtrField< ::std::string>*
WorkerCommand::mutable_symbols() {
// @@protoc_insertion_point(field_mutable_list:loomcomm.WorkerCommand.symbols)
return &symbols_;
}
// -------------------------------------------------------------------
// WorkerResponse
......
......@@ -15,11 +15,11 @@ public:
Task(Id id, int task_type, const std::string &config)
: id(id), task_type(task_type), config(config) {}
int get_id() const {
Id get_id() const {
return id;
}
int get_task_type() const {
Id get_task_type() const {
return task_type;
}
......@@ -39,7 +39,7 @@ public:
protected:
Id id;
int task_type;
Id task_type;
std::vector<Id> inputs;
std::string config;
};
......
......@@ -134,7 +134,7 @@ void Worker::register_worker()
msg.set_port(get_listen_port());
msg.set_cpus(resource_cpus);
for (auto& factory : task_factories) {
for (auto& factory : unregistered_task_factories) {
msg.add_task_types(factory->get_name());
}
......@@ -154,9 +154,9 @@ void Worker::new_task(std::unique_ptr<Task> task)
void Worker::start_task(std::unique_ptr<Task> task)
{
llog->debug("Starting task id={} task_type={}", task->get_id(), task->get_task_type());
auto task_type = task->get_task_type();
assert(task_type >= 0 && task_type < (int) task_factories.size());
auto task_instance = task_factories[task_type]->make_instance(*this, std::move(task));
auto i = task_factories.find(task->get_task_type());
assert(i != task_factories.end());
auto task_instance = i->second->make_instance(*this, std::move(task));
TaskInstance *t = task_instance.get();
active_tasks.push_back(std::move(task_instance));
......@@ -306,6 +306,14 @@ std::unique_ptr<DataUnpacker> Worker::unpack(DataTypeId id)
return i->second->make_unpacker();
}
void Worker::on_dictionary_updated()
{
for (auto &f : unregistered_task_factories) {
loom::Id id = dictionary.lookup_symbol(f->get_name());
task_factories[id] = std::move(f);
}
}
void Worker::check_waiting_tasks()
{
bool something_new = false;
......@@ -412,6 +420,16 @@ void ServerConnection::on_message(const char *data, size_t size)
assert(worker.send_data(msg.address(), msg.id(), with_size));
break;
}
case loomcomm::WorkerCommand_Type_DICTIONARY: {
auto count = msg.symbols_size();
llog->debug("New dictionary ({} symbols)", count);
Dictionary &dictionary = worker.get_dictionary();
for (int i = 0; i < count; i++) {
dictionary.find_or_create(msg.symbols(i));
}
worker.on_dictionary_updated();
} break;
default:
llog->critical("Invalid message");
exit(1);
......
......@@ -5,6 +5,7 @@
#include "taskinstance.h"
#include "unpacking.h"
#include "taskfactory.h"
#include "dictionary.h"
#include <uv.h>
......@@ -78,8 +79,8 @@ public:
}
void add_task_factory(std::unique_ptr<TaskFactory> factory)
{
task_factories.push_back(std::move(factory));
{
unregistered_task_factories.push_back(std::move(factory));
}
InterConnection &get_connection(const std::string &address);
......@@ -116,6 +117,12 @@ public:
void add_unpacker(DataTypeId type_id, std::unique_ptr<UnpackFactory> factory);
std::unique_ptr<DataUnpacker> unpack(DataTypeId id);
// Returns a mutable reference to this worker's `dictionary` member.
Dictionary& get_dictionary() {
return dictionary;
}
void on_dictionary_updated();
private:
void register_worker();
void start_listen();
......@@ -131,7 +138,7 @@ private:
std::vector<std::unique_ptr<TaskInstance>> active_tasks;
std::vector<std::unique_ptr<Task>> ready_tasks;
std::vector<std::unique_ptr<Task>> waiting_tasks;
std::vector<std::unique_ptr<TaskFactory>> task_factories;
std::unordered_map<Id, std::unique_ptr<TaskFactory>> task_factories;
std::unordered_map<int, std::shared_ptr<Data>> public_data;
std::string work_dir;
......@@ -142,12 +149,16 @@ private:
std::unordered_map<std::string, std::unique_ptr<InterConnection>> connections;
std::vector<std::unique_ptr<InterConnection>> nonregistered_connections;
Dictionary dictionary;
std::string server_address;
int server_port;
uv_tcp_t listen_socket;
int listen_port;
std::vector<std::unique_ptr<TaskFactory>> unregistered_task_factories;
static void _on_new_connection(uv_stream_t *stream, int status);
static void _on_getaddrinfo(uv_getaddrinfo_t* handle, int status, struct addrinfo* response);
};
......
......@@ -30,7 +30,7 @@ message WorkerCommand {
TASK = 1;
SEND = 2;
REMOVE = 3;
DICTIONARY = 4;
}
required Type type = 1;
......@@ -45,6 +45,9 @@ message WorkerCommand {
// SEND
optional string address = 10;
optional bool with_size = 11;
// DICTIONARY
repeated string symbols = 100;
}
message WorkerResponse {
......
......@@ -37,11 +37,12 @@ void FreshConnection::on_message(const char *buffer, size_t size)
std::stringstream address;
address << this->connection->get_peername() << ":" << msg.port();
std::vector<std::string> task_types;
std::vector<int> task_types;
task_types.reserve(msg.task_types_size());
Dictionary &dictionary = server.get_dictionary();
for (int i = 0; i < msg.task_types_size(); i++) {
task_types.push_back(msg.task_types(i));
task_types.push_back(dictionary.find_or_create(msg.task_types(i)));
}
auto wconn = std::make_unique<WorkerConnection>(server,
......
......@@ -7,6 +7,8 @@
#include "taskmanager.h"
#include "dummyworker.h"
#include "libloom/dictionary.h"
#include <vector>
......@@ -50,6 +52,10 @@ public:
void on_task_finished(TaskNode &task);
// Returns a mutable reference to this server's `dictionary` member.
loom::Dictionary& get_dictionary() {
return dictionary;
}
private:
void start_listen();
......@@ -67,6 +73,8 @@ private:
TaskManager task_manager;
DummyWorker dummy_worker;
loom::Dictionary dictionary;
static void _on_new_connection(uv_stream_t *stream, int status);
};
......
......@@ -21,8 +21,9 @@ void TaskManager::add_plan(const loomplan::Plan &plan, bool distribute)
int tt_size = plan.task_types_size();
int type_task_translation[tt_size];
Dictionary &dictionary = server.get_dictionary();
for (int i = 0; i < tt_size; i++) {
type_task_translation[i] = translate_task_type(plan.task_types(i));
type_task_translation[i] = dictionary.find_or_create(plan.task_types(i));
}
auto task_size = plan.tasks_size();
......@@ -191,14 +192,3 @@ void TaskManager::distribute_work(TaskNode::Vector &tasks)
}
}
}
int TaskManager::_translate(std::vector<std::string> &table, const std::string &item)
{
auto it = std::find(table.begin(), table.end(), item);
if (it == table.end()) {
int result = table.size();
table.push_back(item);
return result;
}
return std::distance(table.begin(), it);
}
......@@ -33,10 +33,6 @@ public:
void add_plan(const loomplan::Plan &plan, bool distribute=true);
loom::TaskId translate_task_type(const std::string &item) {
return _translate(task_types, item);
}
void on_task_finished(TaskNode &task);
WorkDistribution compute_distribution(TaskNode::Vector &tasks);
......@@ -52,8 +48,6 @@ private:
std::vector<std::string> task_types;
void distribute_work(TaskNode::Vector &tasks);
static int _translate(std::vector<std::string> &table, const std::string &item);
};
......
......@@ -11,24 +11,38 @@ using namespace loom;
WorkerConnection::WorkerConnection(Server &server,
std::unique_ptr<Connection> connection,
const std::string& address,
const std::vector<std::string> &task_types,
const std::vector<loom::Id> &task_types,
int resource_cpus)
: server(server),
connection(std::move(connection)),
resource_cpus(resource_cpus),
address(address)
address(address),