Commit e12fe38c authored by Stanislav Bohm's avatar Stanislav Bohm

ENH: Scheduling improved

parent c510093c
......@@ -9,16 +9,22 @@ LOOM_PROTOCOL_VERSION = 1
class Client(object):
def __init__(self, address, port):
def __init__(self, address, port, info=False):
self.server_address = address
self.server_port = port
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((address, port))
self.connection = Connection(s)
if info:
self.info = []
else:
self.info = None
msg = Register()
msg.type = Register.REGISTER_CLIENT
msg.protocol_version = LOOM_PROTOCOL_VERSION
msg.info = info
self._send_message(msg)
def submit(self, plan, results):
......@@ -42,14 +48,21 @@ class Client(object):
msg = self.connection.receive_message()
cmsg = ClientMessage()
cmsg.ParseFromString(msg)
assert cmsg.type == ClientMessage.DATA
prologue = cmsg.data
data[prologue.id] = self._receive_data()
if cmsg.type == ClientMessage.DATA:
prologue = cmsg.data
data[prologue.id] = self._receive_data()
else:
assert cmsg.type == ClientMessage.INFO
self.add_info(cmsg.info)
if single_result:
return data[results.id]
else:
return [data[task.id] for task in results]
def add_info(self, info):
self.info.append((info.id, info.worker))
def _receive_data(self):
msg_data = Data()
msg_data.ParseFromString(self.connection.receive_message())
......
This diff is collapsed.
......@@ -176,10 +176,21 @@ class Plan(object):
task.set_message(t, task_types)
return msg
def write_dot(self, filename):
def write_dot(self, filename, info=None):
colors = ["red", "green", "blue", "orange", "violet"]
if info:
w = sorted(set(worker for id, worker in info))
workers = {}
for id, worker in info:
workers[id] = w.index(worker)
del w
else:
workers = None
graph = gv.Graph()
for task in self.tasks:
node = graph.node(task.id)
if workers:
node.color = colors[workers[task.id] % len(colors)]
node.label = "{}\n{}".format(str(task.id), task.task_type)
for inp in task.inputs:
graph.node(inp.id).add_arc(node)
......
......@@ -51,9 +51,10 @@ void Connection::close_and_discard_remaining_data()
}
void Connection::accept(uv_tcp_t *listen_socket)
{
{
UV_CHECK(uv_accept((uv_stream_t*) listen_socket, (uv_stream_t*) &socket));
uv_read_start((uv_stream_t *)&socket, _buf_alloc, _on_read);
state = ConnectionOpen;
}
void Connection::start_read()
......
......@@ -126,11 +126,6 @@ void InterConnection::send(Id id, std::shared_ptr<Data> &data, bool with_size)
}
}
void InterConnection::send(std::unique_ptr<SendBuffer> buffer)
{
}
std::string InterConnection::make_address(const std::string &host, int port)
{
std::stringstream s;
......
......@@ -19,8 +19,7 @@ public:
InterConnection(Worker &worker);
~InterConnection();
void send(Id id, std::shared_ptr<Data> &data, bool with_size);
void send(std::unique_ptr<SendBuffer> buffer);
void send(Id id, std::shared_ptr<Data> &data, bool with_size);
void accept(uv_tcp_t *listen_socket) {
connection.accept(listen_socket);
}
......
This diff is collapsed.
This diff is collapsed.
......@@ -132,6 +132,7 @@ void Worker::register_worker()
msg.set_protocol_version(PROTOCOL_VERSION);
msg.set_port(get_listen_port());
msg.set_cpus(resource_cpus);
for (auto& factory : task_factories) {
msg.add_task_types(factory->get_name());
......
......@@ -9,8 +9,14 @@ message Register {
}
required int32 protocol_version = 1;
required Type type = 2;
// Worker
optional int32 port = 3;
repeated string task_types = 4;
optional int32 cpus = 5;
// Client
optional bool info = 10;
}
message ServerMessage {
......@@ -57,7 +63,7 @@ message Data
optional uint64 size = 2;
}
message Feedback {
message Info {
required int32 id = 1;
required string worker = 2;
}
......@@ -65,9 +71,9 @@ message Feedback {
message ClientMessage {
enum Type {
DATA = 1;
FEEDBACK = 2;
INFO = 2;
}
required Type type = 1;
optional DataPrologue data = 2;
optional Feedback feedback = 3;
optional Info info = 3;
}
......@@ -6,8 +6,8 @@
using namespace loom;
ClientConnection::ClientConnection(Server &server, std::unique_ptr<loom::Connection> connection)
: server(server), connection(std::move(connection))
ClientConnection::ClientConnection(Server &server, std::unique_ptr<loom::Connection> connection, bool info_flag)
: server(server), connection(std::move(connection)), info_flag(info_flag)
{
this->connection->set_callback(this);
llog->info("Client {} connected", this->connection->get_peername());
......
......@@ -12,7 +12,7 @@ class Server;
class ClientConnection : public loom::ConnectionCallback {
public:
ClientConnection(Server &server,
std::unique_ptr<loom::Connection> connection);
std::unique_ptr<loom::Connection> connection, bool info_flag);
~ClientConnection();
void on_message(const char *buffer, size_t size);
void on_close();
......@@ -25,9 +25,14 @@ public:
connection->send_buffer(buffer);
}
bool has_info_flag() const {
return info_flag;
}
protected:
Server &server;
std::unique_ptr<loom::Connection> connection;
bool info_flag;
};
......
......@@ -46,15 +46,18 @@ void FreshConnection::on_message(const char *buffer, size_t size)
auto wconn = std::make_unique<WorkerConnection>(server,
std::move(connection),
address.str(),
task_types);
task_types,
msg.cpus());
server.add_worker_connection(std::move(wconn));
server.remove_freshconnection(*this);
return;
}
if (msg.type() == loomcomm::Register_Type_REGISTER_CLIENT) {
bool info_flag = msg.has_info() && msg.info();
auto cconn = std::make_unique<ClientConnection>(server,
std::move(connection));
std::move(connection),
info_flag);
server.add_client_connection(std::move(cconn));
assert(connection.get() == nullptr);
server.remove_freshconnection(*this);
......
......@@ -3,6 +3,7 @@
#include "libloom/utils.h"
#include "libloom/log.h"
#include "libloom/loomcomm.pb.h"
#include <sstream>
......@@ -62,6 +63,26 @@ void Server::remove_freshconnection(FreshConnection &conn)
fresh_connections.erase(i);
}
void Server::on_task_finished(TaskNode &task)
{
assert(client_connection);
if (client_connection->has_info_flag()) {
loomcomm::ClientMessage cmsg;
cmsg.set_type(loomcomm::ClientMessage_Type_INFO);
loomcomm::Info *info = cmsg.mutable_info();
info->set_id(task.get_id());
const auto& owners = task.get_owners();
assert(owners.size());
info->set_worker(owners.back()->get_address());
SendBuffer *buffer = new SendBuffer;
buffer->add(cmsg);
client_connection->send_buffer(buffer);
}
task_manager.on_task_finished(task);
}
void Server::start_listen()
{
struct sockaddr_in addr;
......
......@@ -48,6 +48,8 @@ public:
void add_resend_task(loom::Id id);
void on_task_finished(TaskNode &task);
private:
void start_listen();
......
......@@ -5,6 +5,7 @@
#include "libloom/log.h"
#include <algorithm>
#include <limits.h>
using namespace loom;
......@@ -74,6 +75,11 @@ void TaskManager::on_task_finished(TaskNode &task)
distribute_work(ready);
}
struct _TaskInfo {
int priority;
TaskNode::Vector new_tasks;
};
TaskManager::WorkDistribution TaskManager::compute_distribution(TaskNode::Vector &tasks)
{
WorkDistribution distribution;
......@@ -83,27 +89,44 @@ TaskManager::WorkDistribution TaskManager::compute_distribution(TaskNode::Vector
return distribution;
}
std::unordered_map<WorkerConnection*, TaskNode::Vector> map;
std::unordered_map<WorkerConnection*, _TaskInfo> map;
for (auto &connection : connections) {
int size = connection->get_tasks().size();
int cpus = connection->get_resource_cpus();
map[connection.get()].priority = size - cpus;
}
int c = 0;
size_t size = connections.size();
for (TaskNode* task : tasks) {
auto &inputs = task->get_inputs();
bool done = false;
WorkerConnection *found = nullptr;
int best_priority = INT_MAX;
for (TaskNode *inp : inputs) {
auto &owners = inp->get_owners();
if (owners.size()) {
map[owners[0]].push_back(task);
done = true;
break;
for (WorkerConnection *owner : owners) {
auto &info = map[owner];
if (info.priority < best_priority) {
best_priority = info.priority;
found = owner;
}
}
}
if (done) {
continue;
if (best_priority >= 0) {
for (auto &i : map) {
_TaskInfo &info = i.second;
if (info.priority < best_priority) {
best_priority = info.priority;
found = i.first;
}
}
}
map[connections[c].get()].push_back(task);
c += 1;
c %= size;
auto &info = map[found];
info.new_tasks.push_back(task);
info.priority += 1;
}
/*std::sort(connections.begin(), connections.end(),
......@@ -113,9 +136,11 @@ TaskManager::WorkDistribution TaskManager::compute_distribution(TaskNode::Vector
});*/
distribution.reserve(map.size());
//distribution.reserve(map.size());
for (auto& pair : map) {
distribution.emplace_back(WorkerLoad{*pair.first, std::move(pair.second)});
if (!pair.second.new_tasks.empty()) {
distribution.emplace_back(WorkerLoad{*pair.first, std::move(pair.second.new_tasks)});
}
}
return distribution;
}
......
......@@ -11,13 +11,15 @@ using namespace loom;
WorkerConnection::WorkerConnection(Server &server,
std::unique_ptr<Connection> connection,
const std::string& address,
const std::vector<std::string> &task_types)
const std::vector<std::string> &task_types,
int resource_cpus)
: server(server),
connection(std::move(connection)),
address(address)
resource_cpus(resource_cpus),
address(address)
{
llog->info("Worker {} connected", address);
llog->info("Worker {} connected (cpus={})", address, resource_cpus);
if (this->connection.get()) {
this->connection->set_callback(this);
}
......@@ -34,13 +36,14 @@ void WorkerConnection::on_message(const char *buffer, size_t size)
loomcomm::WorkerResponse msg;
msg.ParseFromArray(buffer, size);
auto it = tasks.find(msg.id());
assert(it != tasks.end());
const auto it = tasks.find(msg.id());
assert(it != tasks.end());
TaskNode *task = it->second;
tasks.erase(it);
task->add_owner(this);
task->set_finished();
server.get_task_manager().on_task_finished(*task);
server.on_task_finished(*task);
}
void WorkerConnection::on_close()
......
......@@ -15,7 +15,8 @@ public:
WorkerConnection(Server &server,
std::unique_ptr<loom::Connection> connection,
const std::string& address,
const std::vector<std::string> &task_types);
const std::vector<std::string> &task_types,
int resource_cpus);
void on_message(const char *buffer, size_t size);
void on_close();
......@@ -41,10 +42,15 @@ public:
assert(0);
}
int get_resource_cpus() const {
return resource_cpus;
}
private:
Server &server;
std::unique_ptr<loom::Connection> connection;
std::unordered_map<loom::Id, TaskNode*> tasks;
int resource_cpus;
std::string address;
std::vector<int> task_type_translates;
......
......@@ -12,6 +12,7 @@ def test_cv_iris(loom_env):
CHUNK_SIZE = 150 / CHUNKS # There are 150 irises
loom_env.start(2)
loom_env.info = True
p = loom_env.plan()
a = p.task_open(IRIS_DATA)
......@@ -38,8 +39,10 @@ def test_cv_iris(loom_env):
task.map_file_in(model, "model")
predict.append(task)
# p.write_dot("test.dot")
results = loom_env.submit(p, predict)
assert len(results) == CHUNKS
for line in results:
assert line.startswith("Accuracy = ")
p.write_dot("test.dot", loom_env.client.info)
......@@ -47,9 +47,14 @@ class Env():
class LoomEnv(Env):
PORT = 19010
info = False
_client = None
def start(self, workers_count, cpus=1):
self.kill_all()
if self.processes:
self._client = None
self.kill_all()
time.sleep(0.2)
server_args = (LOOM_SERVER_BIN,
"--debug",
"--port=" + str(self.PORT))
......@@ -79,11 +84,14 @@ class LoomEnv(Env):
def plan(self):
return client.Plan()
@property
def client(self):
return client.Client("localhost", self.PORT)
if self._client is None:
self._client = client.Client("localhost", self.PORT, self.info)
return self._client
def submit(self, plan, results):
return self.client().submit(plan, results)
return self.client.submit(plan, results)
@pytest.yield_fixture(autouse=True, scope="function")
......
......@@ -9,7 +9,8 @@ typedef
std::unordered_map<WorkerConnection*, TaskSet>
DistMap;
DistMap to_distmap(TaskManager::WorkDistribution dist)
static DistMap
to_distmap(TaskManager::WorkDistribution dist)
{
DistMap map;
for (auto &load : dist) {
......@@ -22,6 +23,14 @@ DistMap to_distmap(TaskManager::WorkDistribution dist)
return map;
}
static std::unique_ptr<WorkerConnection>
simple_worker(Server &server, const std::string &name, int cpus=1)
{
std::vector<std::string> tt;
return std::make_unique<WorkerConnection>(server, nullptr, name, tt, cpus);
}
TEST_CASE( "Server scheduling - separate tasks", "[scheduling]" ) {
Server server(NULL, 0);
TaskManager &manager = server.get_task_manager();
......@@ -42,14 +51,14 @@ TEST_CASE( "Server scheduling - separate tasks", "[scheduling]" ) {
std::vector<std::string> tt;
auto wconn = std::make_unique<WorkerConnection>(server, nullptr, "w1", tt);
auto wconn = simple_worker(server, "w1");
WorkerConnection *w1 = wconn.get();
server.add_worker_connection(std::move(wconn));
auto d2 = to_distmap(manager.compute_distribution(v));
CHECK(d2.size() == 1);
wconn = std::make_unique<WorkerConnection>(server, nullptr, "w2", tt);
wconn = simple_worker(server, "w1");
WorkerConnection *w2 = wconn.get();
server.add_worker_connection(std::move(wconn));
CHECK(server.get_connections().size() == 2);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment