Commit f424553a authored by Stanislav Bohm's avatar Stanislav Bohm

ENH: Symbols for data types

parent c1055fc6
......@@ -26,6 +26,13 @@ class Client(object):
def __init__(self, address, port, info=False):
self.server_address = address
self.server_port = port
self.dictionary_symbols = None
self.dictionary_map = None
self.array_id = None
self.rawdata_id = None
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((address, port))
self.connection = Connection(s)
......@@ -69,6 +76,13 @@ class Client(object):
self.add_info(cmsg.info)
elif cmsg.type == ClientMessage.ERROR:
self.process_error(cmsg)
elif cmsg.type == ClientMessage.DICTIONARY:
self.dictionary_symbols = cmsg.symbols
self.dictionary_map = {}
for i, s in enumerate(self.dictionary_symbols):
self.dictionary_map[s] = i
self.array_id = self.dictionary_map["loom/array"]
self.rawdata_id = self.dictionary_map["loom/data"]
if single_result:
return data[results.id]
......@@ -87,9 +101,9 @@ class Client(object):
msg_data = Data()
msg_data.ParseFromString(self.connection.receive_message())
type_id = msg_data.type_id
if type_id == 300: # Data
if type_id == self.rawdata_id:
return self.connection.read_data(msg_data.size)
if type_id == 400: # Array
if type_id == self.array_id:
return [self._receive_data() for i in xrange(msg_data.length)]
assert 0
......
This diff is collapsed.
......@@ -26,7 +26,7 @@ std::shared_ptr<Data> Data::get_slice(size_t from, size_t to)
void Data::serialize(Worker &worker, SendBuffer &buffer, std::shared_ptr<Data> &data_ptr)
{
loomcomm::Data msg;
msg.set_type_id(get_type_id());
msg.set_type_id(worker.get_dictionary().lookup_symbol(get_type_name()));
msg.set_size(get_size());
auto length = get_length();
if (length) {
......
......@@ -20,7 +20,7 @@ class Data
public:
virtual ~Data();
virtual int get_type_id() = 0;
virtual std::string get_type_name() const = 0;
virtual size_t get_size() = 0;
virtual std::string get_info() = 0;
virtual size_t get_length();
......
......@@ -76,6 +76,11 @@ void Array::serialize_data(Worker &worker, SendBuffer &buffer, std::shared_ptr<D
}
}
std::string Array::get_type_name() const
{
return ArrayUnpacker::get_type_name();
}
ArrayUnpacker::~ArrayUnpacker()
{
......
......@@ -6,16 +6,10 @@
namespace loom {
class Array : public Data {
public:
static const int TYPE_ID = 400;
public:
Array(size_t length, std::unique_ptr<std::shared_ptr<Data>[]> items);
~Array();
int get_type_id() {
return TYPE_ID;
}
size_t get_length() {
return length;
}
......@@ -27,7 +21,8 @@ public:
std::shared_ptr<Data>& get_ref_at_index(size_t index);
void serialize_data(Worker &worker, SendBuffer &buffer, std::shared_ptr<Data> &data_ptr);
void serialize_data(Worker &worker, SendBuffer &buffer, std::shared_ptr<Data> &data_ptr);
std::string get_type_name() const;
private:
size_t length;
......@@ -45,6 +40,10 @@ public:
void on_data_chunk(const char *data, size_t size);
bool on_data_finish(Connection &connection);
static const char* get_type_name() {
return "loom/array";
}
protected:
void finish();
......
......@@ -11,6 +11,11 @@
using namespace loom;
std::string ExternFile::get_type_name() const
{
return "loom/file";
}
ExternFile::ExternFile(const std::string &filename)
: data(nullptr), filename(filename)
{
......
......@@ -9,15 +9,10 @@ namespace loom {
class ExternFile : public Data {
public:
static const int TYPE_ID = 301;
std::string get_type_name() const;
ExternFile(const std::string &filename);
~ExternFile();
int get_type_id() {
return TYPE_ID;
}
size_t get_size() {
return size;
}
......@@ -35,6 +30,7 @@ public:
std::string get_filename() const;
private:
void open();
......
......@@ -19,6 +19,11 @@ Index::~Index()
llog->debug("Disposing index");
}
std::string Index::get_type_name() const
{
return IndexUnpacker::get_type_name();
}
size_t Index::get_length()
{
return length;
......
......@@ -13,8 +13,6 @@ class Worker;
class Index : public Data {
public:
static const int TYPE_ID = 500;
Index(Worker &worker,
std::shared_ptr<Data> &data,
size_t length,
......@@ -22,10 +20,7 @@ public:
~Index();
int get_type_id() {
return TYPE_ID;
}
std::string get_type_name() const;
size_t get_length();
size_t get_size();
std::string get_info();
......@@ -34,6 +29,7 @@ public:
void serialize_data(Worker &worker, SendBuffer &buffer, std::shared_ptr<Data> &data_ptr);
private:
Worker &worker;
std::shared_ptr<Data> data;
......@@ -52,6 +48,10 @@ public:
void on_data_chunk(const char *data, size_t size);
bool on_data_finish(Connection &connection);
static const char* get_type_name() {
return "loom/index";
}
protected:
void finish_data();
......
......@@ -37,7 +37,10 @@ RawData::~RawData()
}
}
std::string RawData::get_type_name() const
{
return RawDataUnpacker::get_type_name();
}
/*char* RawData::init_memonly(size_t size)
{
......@@ -103,7 +106,7 @@ std::string RawData::get_filename() const
void RawData::open(Worker &worker)
{
if (size == 0) {
return;
return;
}
assert(!filename.empty());
int fd = ::open(filename.c_str(), O_RDONLY, S_IRUSR | S_IWUSR);
......@@ -135,7 +138,7 @@ void RawData::map(int fd, bool write)
std::string RawData::get_info()
{
return "RawData";
return "RawData";
}
void RawData::serialize_data(Worker &worker, SendBuffer &buffer, std::shared_ptr<Data> &data_ptr)
......
......@@ -6,15 +6,12 @@
namespace loom {
class RawData : public Data {
public:
static const int TYPE_ID = 300;
public:
RawData();
~RawData();
int get_type_id() {
return TYPE_ID;
}
std::string get_type_name() const;
size_t get_size() {
return size;
......@@ -63,6 +60,11 @@ public:
bool init(Worker &worker, Connection &connection, const loomcomm::Data &msg);
void on_data_chunk(const char *data, size_t size);
bool on_data_finish(Connection &connection);
static const char* get_type_name() {
return "loom/data";
}
protected:
char *pointer = nullptr;
};
......
......@@ -102,6 +102,7 @@ const int Register::kProtocolVersionFieldNumber;
const int Register::kTypeFieldNumber;
const int Register::kPortFieldNumber;
const int Register::kTaskTypesFieldNumber;
const int Register::kDataTypesFieldNumber;
const int Register::kCpusFieldNumber;
const int Register::kInfoFieldNumber;
#endif // !_MSC_VER
......@@ -178,16 +179,18 @@ void Register::Clear() {
::memset(&first, 0, n); \
} while (0)
if (_has_bits_[0 / 32] & 55) {
ZR_(port_, info_);
if (_has_bits_[0 / 32] & 103) {
ZR_(port_, cpus_);
protocol_version_ = 0;
type_ = 1;
info_ = false;
}
#undef OFFSET_OF_FIELD_
#undef ZR_
task_types_.Clear();
data_types_.Clear();
::memset(_has_bits_, 0, sizeof(_has_bits_));
mutable_unknown_fields()->clear();
}
......@@ -266,13 +269,27 @@ bool Register::MergePartialFromCodedStream(
goto handle_unusual;
}
if (input->ExpectTag(34)) goto parse_task_types;
if (input->ExpectTag(40)) goto parse_cpus;
if (input->ExpectTag(42)) goto parse_data_types;
break;
}
// optional int32 cpus = 5;
// repeated string data_types = 5;
case 5: {
if (tag == 40) {
if (tag == 42) {
parse_data_types:
DO_(::google::protobuf::internal::WireFormatLite::ReadString(
input, this->add_data_types()));
} else {
goto handle_unusual;
}
if (input->ExpectTag(42)) goto parse_data_types;
if (input->ExpectTag(48)) goto parse_cpus;
break;
}
// optional int32 cpus = 6;
case 6: {
if (tag == 48) {
parse_cpus:
DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive<
::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>(
......@@ -347,9 +364,15 @@ void Register::SerializeWithCachedSizes(
4, this->task_types(i), output);
}
// optional int32 cpus = 5;
// repeated string data_types = 5;
for (int i = 0; i < this->data_types_size(); i++) {
::google::protobuf::internal::WireFormatLite::WriteString(
5, this->data_types(i), output);
}
// optional int32 cpus = 6;
if (has_cpus()) {
::google::protobuf::internal::WireFormatLite::WriteInt32(5, this->cpus(), output);
::google::protobuf::internal::WireFormatLite::WriteInt32(6, this->cpus(), output);
}
// optional bool info = 10;
......@@ -386,7 +409,7 @@ int Register::ByteSize() const {
this->port());
}
// optional int32 cpus = 5;
// optional int32 cpus = 6;
if (has_cpus()) {
total_size += 1 +
::google::protobuf::internal::WireFormatLite::Int32Size(
......@@ -406,6 +429,13 @@ int Register::ByteSize() const {
this->task_types(i));
}
// repeated string data_types = 5;
total_size += 1 * this->data_types_size();
for (int i = 0; i < this->data_types_size(); i++) {
total_size += ::google::protobuf::internal::WireFormatLite::StringSize(
this->data_types(i));
}
total_size += unknown_fields().size();
GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN();
......@@ -422,6 +452,7 @@ void Register::CheckTypeAndMergeFrom(
void Register::MergeFrom(const Register& from) {
GOOGLE_CHECK_NE(&from, this);
task_types_.MergeFrom(from.task_types_);
data_types_.MergeFrom(from.data_types_);
if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) {
if (from.has_protocol_version()) {
set_protocol_version(from.protocol_version());
......@@ -460,6 +491,7 @@ void Register::Swap(Register* other) {
std::swap(type_, other->type_);
std::swap(port_, other->port_);
task_types_.Swap(&other->task_types_);
data_types_.Swap(&other->data_types_);
std::swap(cpus_, other->cpus_);
std::swap(info_, other->info_);
std::swap(_has_bits_[0], other->_has_bits_[0]);
......@@ -2781,6 +2813,7 @@ bool ClientMessage_Type_IsValid(int value) {
case 1:
case 2:
case 3:
case 4:
return true;
default:
return false;
......@@ -2791,6 +2824,7 @@ bool ClientMessage_Type_IsValid(int value) {
const ClientMessage_Type ClientMessage::DATA;
const ClientMessage_Type ClientMessage::INFO;
const ClientMessage_Type ClientMessage::ERROR;
const ClientMessage_Type ClientMessage::DICTIONARY;
const ClientMessage_Type ClientMessage::Type_MIN;
const ClientMessage_Type ClientMessage::Type_MAX;
const int ClientMessage::Type_ARRAYSIZE;
......@@ -2800,6 +2834,7 @@ const int ClientMessage::kTypeFieldNumber;
const int ClientMessage::kDataFieldNumber;
const int ClientMessage::kInfoFieldNumber;
const int ClientMessage::kErrorFieldNumber;
const int ClientMessage::kSymbolsFieldNumber;
#endif // !_MSC_VER
ClientMessage::ClientMessage()
......@@ -2837,6 +2872,7 @@ ClientMessage::ClientMessage(const ClientMessage& from)
}
void ClientMessage::SharedCtor() {
::google::protobuf::internal::GetEmptyString();
_cached_size_ = 0;
type_ = 1;
data_ = NULL;
......@@ -2895,6 +2931,7 @@ void ClientMessage::Clear() {
if (error_ != NULL) error_->::loomcomm::Error::Clear();
}
}
symbols_.Clear();
::memset(_has_bits_, 0, sizeof(_has_bits_));
mutable_unknown_fields()->clear();
}
......@@ -2968,6 +3005,20 @@ bool ClientMessage::MergePartialFromCodedStream(
} else {
goto handle_unusual;
}
if (input->ExpectTag(42)) goto parse_symbols;
break;
}
// repeated string symbols = 5;
case 5: {
if (tag == 42) {
parse_symbols:
DO_(::google::protobuf::internal::WireFormatLite::ReadString(
input, this->add_symbols()));
} else {
goto handle_unusual;
}
if (input->ExpectTag(42)) goto parse_symbols;
if (input->ExpectAtEnd()) goto success;
break;
}
......@@ -3021,6 +3072,12 @@ void ClientMessage::SerializeWithCachedSizes(
4, this->error(), output);
}
// repeated string symbols = 5;
for (int i = 0; i < this->symbols_size(); i++) {
::google::protobuf::internal::WireFormatLite::WriteString(
5, this->symbols(i), output);
}
output->WriteRaw(unknown_fields().data(),
unknown_fields().size());
// @@protoc_insertion_point(serialize_end:loomcomm.ClientMessage)
......@@ -3058,6 +3115,13 @@ int ClientMessage::ByteSize() const {
}
}
// repeated string symbols = 5;
total_size += 1 * this->symbols_size();
for (int i = 0; i < this->symbols_size(); i++) {
total_size += ::google::protobuf::internal::WireFormatLite::StringSize(
this->symbols(i));
}
total_size += unknown_fields().size();
GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN();
......@@ -3073,6 +3137,7 @@ void ClientMessage::CheckTypeAndMergeFrom(
void ClientMessage::MergeFrom(const ClientMessage& from) {
GOOGLE_CHECK_NE(&from, this);
symbols_.MergeFrom(from.symbols_);
if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) {
if (from.has_type()) {
set_type(from.type());
......@@ -3117,6 +3182,7 @@ void ClientMessage::Swap(ClientMessage* other) {
std::swap(data_, other->data_);
std::swap(info_, other->info_);
std::swap(error_, other->error_);
symbols_.Swap(&other->symbols_);
std::swap(_has_bits_[0], other->_has_bits_[0]);
_unknown_fields_.swap(other->_unknown_fields_);
std::swap(_cached_size_, other->_cached_size_);
......
This diff is collapsed.
......@@ -13,6 +13,7 @@ class UnpackFactory
public:
virtual ~UnpackFactory();
virtual std::unique_ptr<DataUnpacker> make_unpacker() = 0;
virtual const char* get_type_name() const = 0;
};
template<typename T> class SimpleUnpackFactory : public UnpackFactory
......@@ -21,6 +22,10 @@ public:
std::unique_ptr<DataUnpacker> make_unpacker() {
return std::make_unique<T>();
}
virtual const char* get_type_name() const {
return T::get_type_name();
}
};
}
......
......@@ -72,14 +72,11 @@ Worker::Worker(uv_loop_t *loop,
llog->info("Using '{}' as working directory", work_dir);
}
add_unpacker(RawData::TYPE_ID,
std::make_unique<SimpleUnpackFactory<RawDataUnpacker>>());
add_unpacker(std::make_unique<SimpleUnpackFactory<RawDataUnpacker>>());
add_unpacker(Array::TYPE_ID,
std::make_unique<SimpleUnpackFactory<ArrayUnpacker>>());
add_unpacker(std::make_unique<SimpleUnpackFactory<ArrayUnpacker>>());
add_unpacker(Index::TYPE_ID,
std::make_unique<SimpleUnpackFactory<IndexUnpacker>>());
add_unpacker(std::make_unique<SimpleUnpackFactory<IndexUnpacker>>());
resource_cpus = 1;
}
......@@ -171,6 +168,10 @@ void Worker::register_worker()
msg.add_task_types(factory->get_name());
}
for (auto& factory : unregistered_unpack_factories) {
msg.add_data_types(factory->get_type_name());
}
server_conn.send_message(msg);
}
......@@ -189,7 +190,7 @@ void Worker::start_task(std::unique_ptr<Task> task)
llog->debug("Starting task id={} task_type={}", task->get_id(), task->get_task_type());
auto i = task_factories.find(task->get_task_type());
if (unlikely(i == task_factories.end())) {
llog->critical("Task with unknown type received");
llog->critical("Task with unknown type {} received", task->get_task_type());
exit(1);
}
auto task_instance = i->second->make_instance(*this, std::move(task));
......@@ -328,11 +329,9 @@ void Worker::set_cpus(int value)
llog->info("Number of CPUs for worker: {}", value);
}
void Worker::add_unpacker(DataTypeId type_id, std::unique_ptr<UnpackFactory> factory)
void Worker::add_unpacker(std::unique_ptr<UnpackFactory> factory)
{
auto &f = unpack_factories[type_id];
assert(f.get() == nullptr);
f = std::move(factory);
unregistered_unpack_factories.push_back(std::move(factory));
}
std::unique_ptr<DataUnpacker> Worker::unpack(DataTypeId id)
......@@ -346,8 +345,17 @@ void Worker::on_dictionary_updated()
{
for (auto &f : unregistered_task_factories) {
loom::Id id = dictionary.lookup_symbol(f->get_name());
llog->debug("Registering task_factory: {} = {}", f->get_name(), id);
task_factories[id] = std::move(f);
}
unregistered_task_factories.clear();
for (auto &f : unregistered_unpack_factories) {
loom::Id id = dictionary.lookup_symbol(f->get_type_name());
llog->debug("Registering unpack_factory: {} = {}", f->get_type_name(), id);
unpack_factories[id] = std::move(f);
}
unregistered_unpack_factories.clear();
}
void Worker::check_waiting_tasks()
......
......@@ -122,7 +122,8 @@ public:
void check_ready_tasks();
void set_cpus(int value);
void add_unpacker(DataTypeId type_id, std::unique_ptr<UnpackFactory> factory);
void add_unpacker(std::unique_ptr<UnpackFactory> factory);
std::unique_ptr<DataUnpacker> unpack(DataTypeId id);
Dictionary& get_dictionary() {
......@@ -166,6 +167,7 @@ private:
int listen_port;
std::vector<std::unique_ptr<TaskFactory>> unregistered_task_factories;
std::vector<std::unique_ptr<UnpackFactory>> unregistered_unpack_factories;
static void _on_new_connection(uv_stream_t *stream, int status);
static void _on_getaddrinfo(uv_getaddrinfo_t* handle, int status, struct addrinfo* response);
......
......@@ -13,7 +13,8 @@ message Register {
// Worker
optional int32 port = 3;
repeated string task_types = 4;
optional int32 cpus = 5;
repeated string data_types = 5;
optional int32 cpus = 6;
// Client
optional bool info = 10;
......@@ -100,9 +101,11 @@ message ClientMessage {
DATA = 1;
INFO = 2;
ERROR = 3;
DICTIONARY = 4;
}
required Type type = 1;
optional DataPrologue data = 2;
optional Info info = 3;
optional Error error = 4;
repeated string symbols = 5;
}
......@@ -2,6 +2,7 @@
#include "server.h"
#include "libloom/loomplan.pb.h"
#include "libloom/loomcomm.pb.h"
#include "libloom/log.h"
using namespace loom;
......@@ -11,6 +12,20 @@ ClientConnection::ClientConnection(Server &server, std::unique_ptr<loom::Connect
{
this->connection->set_callback(this);
llog->info("Client {} connected", this->connection->get_peername());
// Send dictionary
loomcomm::ClientMessage cmsg;
cmsg.set_type(loomcomm::ClientMessage_Type_DICTIONARY);
std::vector<std::string> symbols = server.get_dictionary().get_all_symbols();
for (std::string &symbol : symbols) {
std::string *s = cmsg.add_symbols();
*s = symbol;
}
SendBuffer *send_buffer = new SendBuffer();
send_buffer->add(cmsg);
this->connection->send_buffer(send_buffer);
// End of send dictionary
}
ClientConnection::~ClientConnection()
......
......@@ -37,7 +37,7 @@ void FreshConnection::on_message(const char *buffer, size_t size)
std::stringstream address;
address << this->connection->get_peername() << ":" << msg.port();
std::vector<int> task_types;
std::vector<int> task_types, data_types;
task_types.reserve(msg.task_types_size());
Dictionary &dictionary = server.get_dictionary();
......@@ -45,10 +45,15 @@ void FreshConnection::on_message(const char *buffer, size_t size)
task_types.push_back(dictionary.find_or_create(msg.task_types(i)));
}
for (int i = 0; i < msg.data_types_size(); i++) {
data_types.push_back(dictionary.find_or_create(msg.data_types(i)));
}
auto wconn = std::make_unique<WorkerConnection>(server,
std::move(connection),
address.str(),
task_types,
data_types,
msg.cpus());
server.add_worker_connection(std::move(wconn));
......
......@@ -110,6 +110,20 @@ void Server::inform_about_task_error(Id id, WorkerConnection &wconn, const std::
exit(1);
}
void Server::send_dictionary(Connection &connection)
{
loomcomm::WorkerCommand msg;
msg.set_type(loomcomm::WorkerCommand_Type_DICTIONARY);
std::vector<std::string> symbols = dictionary.get_all_symbols();
for (std::string &symbol : symbols) {
std::string *s = msg.add_symbols();
*s = symbol;
}
SendBuffer *send_buffer = new SendBuffer();
send_buffer->add(msg);
connection.send_buffer(send_buffer);
}
void Server::start_listen()
{
struct sockaddr_in addr;
......
......@@ -11,7 +11,6 @@
#include <vector>
class Server {
public:
......@@ -65,6 +64,8 @@ public:
return id;
}
void send_dictionary(loom::Connection &connection);
private:
void start_listen();
......
......@@ -12,37 +12,24 @@ WorkerConnection::WorkerConnection(Server &server,
std::unique_ptr<Connection> connection,
const std::string& address,
const std::vector<loom::Id> &task_types,
const std::vector<loom::Id> &data_types,
int resource_cpus)
: server(server),
connection(std::move(connection)),
resource_cpus(resource_cpus),
address(address),
task_types(task_types)
task_types(task_types),
data_types(data_types)
{
llog->info("Worker {} connected (cpus={})", address, resource_cpus);
if (this->connection) {
this->connection->set_callback(this);
loomcomm::WorkerCommand msg;
msg.set_type(loomcomm::WorkerCommand_Type_DICTIONARY);
std::vector<std::string> symbols = server.get_dictionary().get_all_symbols();
for (std::string &symbol : symbols) {
std::string *s = msg.add_symbols();
*s = symbol;
}
SendBuffer *send_buffer = new SendBuffer();
send_buffer->add(msg);
this->connection->send_buffer(send_buffer);
server.send_dictionary(*this->connection);
}
if (unlikely(task_types.size() == 0)) {
llog->warn("No task_type has been registered by worker");
}
/*auto &manager = server.get_task_manager();
for (size_t i = 0; i < task_types.size(); i++) {
task_type_translates[i] = manager.translate_task_type(task_types[i]);<