Newer
Older
import json
from pathlib import Path
from typing import List
from functions import train_model
def train_model_task(config: Path, output_path: Path):
    """Train a model with hyperparameters read from a JSON config file.

    The file at ``config`` is decoded as JSON. Each top-level key of the
    resulting object is forwarded to ``train_model`` as a keyword argument,
    together with ``output_path``.

    Args:
        config: Path to a UTF-8 JSON file whose top-level value is an object
            mapping parameter names to values.
        output_path: Passed through to ``train_model`` as ``output_path``.

    Returns:
        Whatever ``train_model`` returns.

    Raises:
        TypeError: If the config's top-level JSON value is not an object.
        json.JSONDecodeError: If the config file is not valid JSON.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config, encoding="utf-8") as f:
        params = json.load(f)
    if not isinstance(params, dict):
        raise TypeError(
            f"config {config} must contain a JSON object, got {type(params).__name__}"
        )
    # NOTE(review): assumes train_model accepts the config keys plus an
    # ``output_path`` keyword argument (per the original TODO) — confirm
    # against functions.train_model. A config key named "output_path" would
    # collide with the explicit argument and raise TypeError.
    return train_model(**params, output_path=output_path)
def postprocess_task(result_files: List[Path], output_file: Path):
    """Collect training results and write a report ranked by accuracy.

    Each file in ``result_files`` is parsed as JSON. Every parsed result must
    provide an ``"accuracy"`` value and a ``"parameters"`` mapping with
    ``"learning_rate"`` and ``"batch_size"`` entries. Results are sorted from
    highest to lowest accuracy and written to ``output_file`` as a header line
    followed by one line per result (accuracy shown as a percentage with two
    decimals).

    Args:
        result_files: Paths of the JSON result files to read.
        output_file: Path of the text report; overwritten if it exists.

    Raises:
        KeyError: If a result lacks one of the fields listed above.
    """
    results = []
    for result_file in result_files:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(result_file, encoding="utf-8") as f:
            results.append(json.load(f))

    # Highest accuracy first; the sort is stable (even with reverse=True),
    # so results with equal accuracy keep their input order.
    results.sort(key=lambda r: r["accuracy"], reverse=True)

    with open(output_file, "w", encoding="utf-8") as f:
        print("Training results sorted by accuracy:", file=f)
        for result in results:
            params = result["parameters"]
            learning_rate = params["learning_rate"]
            batch_size = params["batch_size"]
            accuracy = result["accuracy"]
            print(
                f"Learning rate={learning_rate}, batch_size={batch_size}: {accuracy * 100.0:.2f}%",
                file=f,
            )