diff --git a/python/loom/client/connection.py b/python/loom/client/connection.py index ab21fec69ef1aef9b92bc2d823625ca157718cab..104bee2c0d77f85a240bd4213975795999db4702 100644 --- a/python/loom/client/connection.py +++ b/python/loom/client/connection.py @@ -14,30 +14,33 @@ class Connection(object): self.socket.close() def receive_message(self): - while True: - size = len(self.data) - if size >= 8: - msg_size = u64.unpack(self.data[:8])[0] - msg_size += 8 - if size >= msg_size: - message = self.data[8:msg_size] - self.data = self.data[msg_size:] - return message - new_data = self.socket.recv(65536) - if not new_data: - raise Exception("Connection to server lost") - self.data += new_data + while len(self.data) < 8: + self.data += self.socket.recv(655360) + msg_size = u64.unpack(self.data[:8])[0] + if len(self.data) >= msg_size + 8: + msg_size = msg_size + 8 + message = self.data[8:msg_size] + self.data = self.data[msg_size:] + return message + self.data = self.data[8:] + return self.read_data(msg_size) def read_data(self, data_size): - result = bytes() + chunks = [] while True: - change = min(data_size, len(self.data)) - result += self.data[:change] - self.data = self.data[change:] - data_size -= change - if data_size == 0: - return result - self.data = self.socket.recv(65536) + if data_size >= len(self.data): + chunks.append(self.data) + data_size -= len(self.data) + if data_size == 0: + self.data = b"" + return b"".join(chunks) + else: + chunks.append(self.data[:data_size]) + self.data = self.data[data_size:] + return b"".join(chunks) + self.data = self.socket.recv(655360) + if not self.data: + raise Exception("Connection to server lost") def send_message(self, data): data = u64.pack(len(data)) + data diff --git a/src/libloomw/data/pyobj.cpp b/src/libloomw/data/pyobj.cpp index 6b1ea9abadc949fd76730fc96775782681b56dc5..37f92782a09bddf229f2d7391e24c12aa1f1a665 100644 --- a/src/libloomw/data/pyobj.cpp +++ b/src/libloomw/data/pyobj.cpp @@ -14,10 +14,7 @@ loom::PyObj::~PyObj() { PyGILState_STATE gstate; gstate = PyGILState_Ensure(); - - assert(obj->ob_refcnt == 1); Py_DecRef(obj); - PyGILState_Release(gstate); } diff --git a/src/server/taskmanager.cpp b/src/server/taskmanager.cpp index f1c85c0b615fff7607131f436aa8b52c1868e334..9bb67a9f3c1cebdfcee2c6d352dedde6dc19ebf8 100644 --- a/src/server/taskmanager.cpp +++ b/src/server/taskmanager.cpp @@ -86,6 +86,8 @@ void TaskManager::remove_node(TaskNode &node) }); node.set_not_needed(); //cstate.remove_node(node); + node.reset_owners(); + cstate.remove_node(node); } void TaskManager::on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing) @@ -296,7 +298,7 @@ void TaskManager::run_task_distribution() } void TaskManager::trash_all_tasks() -{ +{ for (auto &wc : server.get_connections()) { wc->change_residual_tasks(wc->get_checkpoint_loads()); wc->change_checkpoint_loads(-wc->get_checkpoint_loads()); diff --git a/tests/client/test_py.py b/tests/client/test_py.py index f513e8229c7341f4a50342bf9b190ec787a0de39..4c436e36a77ae3c30e71f77a4824c17b1a958c7c 100644 --- a/tests/client/test_py.py +++ b/tests/client/test_py.py @@ -287,3 +287,30 @@ def test_py_task_deserialization3(loom_env): objs = tuple(tasks.py_value(str(i + 1000)) for i in range(100)) x = tasks.array_make(objs) loom_env.submit_and_gather(x) + + +def test_rewrap_test(loom_env): + @tasks.py_task(context=True, n_direct_args=4) + def init(ctx): + content = [1, 2, 3] + return ctx.wrap(content) + + @tasks.py_task(context=True) + def center(ctx, train_test): + train, test = [t.unwrap() for t in train_test] + return [ctx.wrap(t) for t in (train, test)] + + @tasks.py_task(context=True) + def remove_empty_rows(ctx, train, test): + train = list(train.unwrap()) + test = list(test.unwrap()) + return [ctx.wrap(t) for t in (train, test)] + + train = init() + test = init() + mean_task = center(remove_empty_rows(train, test)) + smurff_tasks = [mean_task] + + loom_env.start(1) + futures = loom_env.client.submit(smurff_tasks) + results = loom_env.client.gather(futures) \ No newline at end of file