diff --git a/python/loom/rview/report.py b/python/loom/rview/report.py index 2e1004698bab76cf499ece5689aa26f803f96e4a..b1eefedaeb53ef2f8f144af6a7bb84e67de5b09f 100644 --- a/python/loom/rview/report.py +++ b/python/loom/rview/report.py @@ -44,7 +44,24 @@ class Report: labels.add(task.label) else: labels.add(symbols[task.task_type]) - return sorted(labels) + + label_groups = {} + group_names = [] + labels = sorted(labels) + + for label in labels: + if ":" in label: + group = label.split(":")[0].strip() + else: + group = label + try: + index = group_names.index(group) + except ValueError: + index = len(group_names) + group_names.append(group) + label_groups[label] = index + + return label_groups, group_names def create_graph(self): TASK_START = loomcomm.Event.TASK_START @@ -56,8 +73,8 @@ class Report: if event.type == TASK_START: task_workers[event.id] = event.worker_id - labels = self.collect_labels() - colors = [dot_color(c) for c in generate_colors(len(labels))] + label_groups, group_names = self.collect_labels() + colors = [dot_color(c) for c in generate_colors(len(group_names))] for i, task in enumerate(self.report_msg.plan.tasks): node = graph.node(i) @@ -69,8 +86,8 @@ class Report: if task_workers: node.label += "\nw={}".format(task_workers[i]) - node.fillcolor = colors[labels.index(label)] - node.color = colors[labels.index(label)] + node.fillcolor = colors[label_groups[label]] + node.color = colors[label_groups[label]] for j in task.input_ids: graph.node(j).add_arc(node) return graph @@ -81,7 +98,7 @@ class Report: workers = {} symbols = self.symbols - labels = self.collect_labels() + label_groups, group_names = self.collect_labels() for event in self.report_msg.events: lst = workers.get(event.worker_id) @@ -109,7 +126,7 @@ class Report: colors = [] y_labels = [] - color_list = generate_colors(len(labels)) + color_list = generate_colors(len(group_names)) tasks = self.report_msg.plan.tasks index = 0 @@ -126,7 +143,7 @@ class Report: label = task.label else: label = symbols[task.task_type] - colors.append(color_list[labels.index(label)]) + colors.append(color_list[label_groups[label]]) index += 1 index += 1 @@ -136,4 +153,4 @@ class Report: colors, y_labels, [(l, color_list[i]) - for i, l in enumerate(labels)]) + for i, l in enumerate(group_names)]) diff --git a/tests/client/cv_test.py b/tests/client/cv_test.py index 763ea624e431e8ec717055a34da96856926d26e9..4b792af28f30ba662fca12d9d0b19af73f720813 100644 --- a/tests/client/cv_test.py +++ b/tests/client/cv_test.py @@ -26,10 +26,10 @@ def test_cv_iris(loom_env): for i in xrange(CHUNKS)] models = [] - for ts in trainsets: + for i, ts in enumerate(trainsets): model = tasks.run("svm-train data", [(ts, "data")], ["data.model"]) - model.label = "svm-train" + model.label = "svm-train: {}".format(i) models.append(model) predict = []