Commit 201ba8e4 authored by Stanislav Bohm's avatar Stanislav Bohm

ENH: ":" delimiter in labels

parent 451319ed
......@@ -44,7 +44,24 @@ class Report:
labels.add(task.label)
else:
labels.add(symbols[task.task_type])
return sorted(labels)
label_groups = {}
group_names = []
labels = sorted(labels)
for label in labels:
if ":" in label:
group = label.split(":")[0].strip()
else:
group = label
try:
index = group_names.index(group)
except ValueError:
index = len(group_names)
group_names.append(group)
label_groups[label] = index
return label_groups, group_names
def create_graph(self):
TASK_START = loomcomm.Event.TASK_START
......@@ -56,8 +73,8 @@ class Report:
if event.type == TASK_START:
task_workers[event.id] = event.worker_id
labels = self.collect_labels()
colors = [dot_color(c) for c in generate_colors(len(labels))]
label_groups, group_names = self.collect_labels()
colors = [dot_color(c) for c in generate_colors(len(group_names))]
for i, task in enumerate(self.report_msg.plan.tasks):
node = graph.node(i)
......@@ -69,8 +86,8 @@ class Report:
if task_workers:
node.label += "\nw={}".format(task_workers[i])
node.fillcolor = colors[labels.index(label)]
node.color = colors[labels.index(label)]
node.fillcolor = colors[label_groups[label]]
node.color = colors[label_groups[label]]
for j in task.input_ids:
graph.node(j).add_arc(node)
return graph
......@@ -81,7 +98,7 @@ class Report:
workers = {}
symbols = self.symbols
labels = self.collect_labels()
label_groups, group_names = self.collect_labels()
for event in self.report_msg.events:
lst = workers.get(event.worker_id)
......@@ -109,7 +126,7 @@ class Report:
colors = []
y_labels = []
color_list = generate_colors(len(labels))
color_list = generate_colors(len(group_names))
tasks = self.report_msg.plan.tasks
index = 0
......@@ -126,7 +143,7 @@ class Report:
label = task.label
else:
label = symbols[task.task_type]
colors.append(color_list[labels.index(label)])
colors.append(color_list[label_groups[label]])
index += 1
index += 1
......@@ -136,4 +153,4 @@ class Report:
colors,
y_labels,
[(l, color_list[i])
for i, l in enumerate(labels)])
for i, l in enumerate(group_names)])
......@@ -26,10 +26,10 @@ def test_cv_iris(loom_env):
for i in xrange(CHUNKS)]
models = []
for ts in trainsets:
for i, ts in enumerate(trainsets):
model = tasks.run("svm-train data",
[(ts, "data")], ["data.model"])
model.label = "svm-train"
model.label = "svm-train: {}".format(i)
models.append(model)
predict = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment