From 201ba8e46f3d5528a6915e1b03d05f71d2dda423 Mon Sep 17 00:00:00 2001 From: Stanislav Bohm <stanislav.bohm@vsb.cz> Date: Thu, 27 Oct 2016 10:08:00 +0200 Subject: [PATCH] ENH: ":" delimiter in labels --- python/loom/rview/report.py | 35 ++++++++++++++++++++++++++--------- tests/client/cv_test.py | 4 ++-- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/python/loom/rview/report.py b/python/loom/rview/report.py index 2e10046..b1eefed 100644 --- a/python/loom/rview/report.py +++ b/python/loom/rview/report.py @@ -44,7 +44,24 @@ class Report: labels.add(task.label) else: labels.add(symbols[task.task_type]) - return sorted(labels) + + label_groups = {} + group_names = [] + labels = sorted(labels) + + for label in labels: + if ":" in label: + group = label.split(":")[0].strip() + else: + group = label + try: + index = group_names.index(group) + except ValueError: + index = len(group_names) + group_names.append(group) + label_groups[label] = index + + return label_groups, group_names def create_graph(self): TASK_START = loomcomm.Event.TASK_START @@ -56,8 +73,8 @@ class Report: if event.type == TASK_START: task_workers[event.id] = event.worker_id - labels = self.collect_labels() - colors = [dot_color(c) for c in generate_colors(len(labels))] + label_groups, group_names = self.collect_labels() + colors = [dot_color(c) for c in generate_colors(len(group_names))] for i, task in enumerate(self.report_msg.plan.tasks): node = graph.node(i) @@ -69,8 +86,8 @@ class Report: if task_workers: node.label += "\nw={}".format(task_workers[i]) - node.fillcolor = colors[labels.index(label)] - node.color = colors[labels.index(label)] + node.fillcolor = colors[label_groups[label]] + node.color = colors[label_groups[label]] for j in task.input_ids: graph.node(j).add_arc(node) return graph @@ -81,7 +98,7 @@ class Report: workers = {} symbols = self.symbols - labels = self.collect_labels() + label_groups, group_names = self.collect_labels() for event in self.report_msg.events: lst = workers.get(event.worker_id) @@ -109,7 +126,7 @@ class Report: colors = [] y_labels = [] - color_list = generate_colors(len(labels)) + color_list = generate_colors(len(group_names)) tasks = self.report_msg.plan.tasks index = 0 @@ -126,7 +143,7 @@ class Report: label = task.label else: label = symbols[task.task_type] - colors.append(color_list[labels.index(label)]) + colors.append(color_list[label_groups[label]]) index += 1 index += 1 @@ -136,4 +153,4 @@ class Report: colors, y_labels, [(l, color_list[i]) - for i, l in enumerate(labels)]) + for i, l in enumerate(group_names)]) diff --git a/tests/client/cv_test.py b/tests/client/cv_test.py index 763ea62..4b792af 100644 --- a/tests/client/cv_test.py +++ b/tests/client/cv_test.py @@ -26,10 +26,10 @@ def test_cv_iris(loom_env): for i in xrange(CHUNKS)] models = [] - for ts in trainsets: + for i, ts in enumerate(trainsets): model = tasks.run("svm-train data", [(ts, "data")], ["data.model"]) - model.label = "svm-train" + model.label = "svm-train: {}".format(i) models.append(model) predict = [] -- GitLab