Skip to content
Snippets Groups Projects
main.py 1.67 KiB
Newer Older
  • Learn to ignore specific revisions
  • Jakub Beránek's avatar
    Jakub Beránek committed
    import glob
    import shutil
    from pathlib import Path
    
    from hyperqueue import Job, LocalCluster
    from hyperqueue.visualization import visualize_job
    
    from tasks import postprocess_task, train_model_task
    
    if __name__ == "__main__":
        # Spawn a HQ server
        with LocalCluster() as cluster:
            # Add a single HyperQueue worker to the server
            cluster.start_worker()

            # Create a client and a job
            client = cluster.client()
            job = Job()

            # Directory where output will be stored; wipe it first so stale
            # results from a previous run cannot leak into this one.
            output_dir = Path("output")
            shutil.rmtree(output_dir, ignore_errors=True)
            output_dir.mkdir(parents=True, exist_ok=True)

            # One training task per config file. Sorting the glob makes the
            # index -> config mapping deterministic across runs.
            train_tasks = []
            result_files = []
            for (index, config_file) in enumerate(sorted(glob.glob("configs/*.json"))):
                # Each task writes to its own result file (output/<index>-result.json)
                # so tasks running in parallel cannot clobber each other.
                result_file = output_dir / f"{index}-result.json"
                task = job.function(
                    train_model_task,
                    args=(config_file, result_file),
                )
                train_tasks.append(task)
                result_files.append(result_file)

            # Final postprocessing task: aggregates all training results into a
            # single file. `deps=train_tasks` makes it run only after every
            # training task has finished.
            postprocessing_result_path = output_dir / "postprocessing-result.json"
            job.function(
                postprocess_task,
                args=(result_files, postprocessing_result_path),
                deps=train_tasks,
            )

            # Submit the job
            submitted = client.submit(job)

            # Visualize the created job using the DOT format.
            # You can render the graph using `$ xdot job.dot`.
            visualize_job(job, "job.dot")

            # Wait until the job completes
            client.wait_for_jobs([submitted])