Skip to content
Snippets Groups Projects
tasks.py 1.02 KiB
Newer Older
  • Learn to ignore specific revisions
  • Jakub Beránek's avatar
    Jakub Beránek committed
    import json
    from pathlib import Path
    from typing import List
    
    from functions import train_model
    
    
    def train_model_task(config: Path, output_path: Path):
        """
        Intended to train a model from a JSON configuration file (currently a stub).

        TODO: read the JSON config from `config` and pass its parameters (along with `output_path`)
        to the `train_model` function.

        :param config: path of the JSON configuration file to read.
        :param output_path: path to forward to `train_model` together with the config parameters.
        :return: None -- nothing is implemented yet.
        """
        return None
    
    
    def postprocess_task(result_files: List[Path], output_file: Path):
        """
        Aggregate training results and write a report sorted by accuracy.

        Each file in `result_files` is parsed as JSON and must be an object with an
        "accuracy" number and a "parameters" object holding "learning_rate" and
        "batch_size". The report written to `output_file` lists one line per
        result, highest accuracy first; accuracy is printed multiplied by 100
        (presumably a fraction in [0, 1] -- confirm against the producer).

        :param result_files: paths of the JSON result files to aggregate.
        :param output_file: path of the text report; it is overwritten.
        :raises KeyError: if a result lacks one of the keys listed above.
        """
        results = []
        for result_file in result_files:
            # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
            with open(result_file, encoding="utf-8") as f:
                results.append(json.load(f))

        # Sort before opening the output file, so a malformed result raises
        # without truncating an existing report.
        results.sort(key=lambda r: r["accuracy"], reverse=True)

        with open(output_file, "w", encoding="utf-8") as f:
            print("Training results sorted by accuracy:", file=f)
            for result in results:
                parameters = result["parameters"]
                learning_rate = parameters["learning_rate"]
                batch_size = parameters["batch_size"]
                accuracy = result["accuracy"]
                print(
                    f"Learning rate={learning_rate}, batch_size={batch_size}: {accuracy * 100.0:.2f}%",
                    file=f
                )