Commit abe98999 authored by Stanislav Bohm's avatar Stanislav Bohm

ENH: Client function: make_dry_report

parent d43381d0
from .client import Client, LoomException, TaskFailed # noqa
from .client import Client, LoomException, TaskFailed, make_dry_report # noqa
from .plan import Plan # noqa
from .planbuilder import PlanBuilder, cpus, cpu1 # noqa
......@@ -81,7 +81,7 @@ class Client(object):
assert 0
if report:
self._write_report(report_data, report)
write_report(report_data, report)
if single_result:
return data[results.id]
......@@ -100,10 +100,6 @@ class Client(object):
plan.set_message(report_msg.plan, self.symbols)
return report_msg
def _write_report(self, report_data, report_filename):
with open(report_filename + ".report", "w") as f:
f.write(report_data.SerializeToString())
def _read_symbols(self):
msg = self.connection.receive_message()
cmsg = ClientMessage()
......@@ -138,3 +134,23 @@ class Client(object):
def _send_message(self, message):
data = message.SerializeToString()
self.connection.send_message(data)
def make_dry_report(plan, report_filename):
# Create symbols
symbols = sorted(plan.collect_symbols())
symbol_table = {}
for i, s in enumerate(symbols):
symbol_table[s] = i
# Create report
report_data = Report()
report_data.symbols.extend(symbols)
plan.set_message(report_data.plan, symbol_table)
write_report(report_data, report_filename)
def write_report(report_data, report_filename):
with open(report_filename, "w") as f:
f.write(report_data.SerializeToString())
......@@ -9,6 +9,7 @@ POLICY_SCHEDULER = loomplan_pb2.Task.POLICY_SCHEDULER
class Task(object):
task_type = None
inputs = ()
id = None
config = ""
......@@ -65,6 +66,14 @@ class Plan(object):
self.tasks.append(task)
return task
def collect_symbols(self):
symbols = set()
for task in self.tasks:
if task.resource_request:
symbols.update(task.resource_request.names)
symbols.add(task.task_type)
return symbols
def set_message(self, msg, symbols):
requests = set()
for task in self.tasks:
......
......@@ -5,6 +5,7 @@ import os
IRIS_DATA = os.path.join(LOOM_TEST_DATA_DIR, "iris.data")
loom_env # silence flake8
import client # noqa
def test_cv_iris(loom_env):
......@@ -37,7 +38,9 @@ def test_cv_iris(loom_env):
[(chunk, "testdata"), (model, "model")])
predict.append(task)
results = loom_env.submit(p, predict, report="cv")
loom_env.make_dry_report(p.plan, "dry.report")
results = loom_env.submit(p, predict, report="cv.report")
assert len(results) == CHUNKS
for line in results:
......
......@@ -97,6 +97,12 @@ class LoomEnv(Env):
report = os.path.join(LOOM_TEST_BUILD_DIR, report)
return self.client.submit(plan, results, report)
def make_dry_report(self, plan, filename):
if isinstance(plan, client.PlanBuilder):
plan = plan.plan
filename = os.path.join(LOOM_TEST_BUILD_DIR, filename)
return client.make_dry_report(plan, filename)
@pytest.yield_fixture(autouse=True, scope="function")
def loom_env():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment