#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import sys
import math
import os

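# Convert every TSPTW/VRPTW benchmark file in a folder into a JSON problem
# definition and merge them, together with the given vehicle/fleet settings,
# into a single JSON file.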

# The TSPTW/VRPTW benchmarks express costs (and input timings) as doubles,
# while published results are usually quoted to 2 decimal places. To retain
# an extra decimal digit through integer rounding, travel costs are multiplied
# by CUSTOM_PRECISION before being rounded, and the time windows and service
# times read from an instance are multiplied by the same factor.
# benchmarks/compare_to_BKS.py adjusts its comparisons accordingly.
CUSTOM_PRECISION: int = 10

# argv[1] Folder of VRPTW Solomon/Homberger benchmarks from vroom-scripts

# argv[2] Json file with minimal settings:
# "Vehicle Definitions"
# "Available Fleets"
# "Problem-Vehicle Combinations Mapping"

# argv[3] Output file name

# argv[4] Scaling factor (positive integer): multiplies the job time windows
# and the duration-matrix values (service times are not scaled)


def nint(x):
    """Return ``x`` rounded to the nearest integer, with ties rounded up.

    ``math.floor`` is used rather than ``int`` truncation: ``int(x + 0.5)``
    truncates towards zero and therefore rounds negative values the wrong
    way (e.g. -1.7 -> -1 instead of -2). For non-negative ``x`` the result is
    identical to the truncating form.
    """
    # int() keeps the return type an int even where floor returns a float.
    return int(math.floor(x + 0.5))


def euc_2D(c1, c2, PRECISION=1):
    """Return the scaled, rounded Euclidean distance between two points.

    ``c1`` and ``c2`` are indexable (x, y) pairs. Their straight-line
    distance is multiplied by ``PRECISION`` and rounded with ``nint``.
    """
    dx, dy = c1[0] - c2[0], c1[1] - c2[1]
    scaled_distance = PRECISION * math.sqrt(dx * dx + dy * dy)
    return nint(scaled_distance)


# Number of the input line currently being parsed; advanced by the parse_*
# helpers and reported in their error messages.
line_no: int = 0


def get_matrix(coords, PRECISION=1):
    """Build the symmetric all-pairs distance matrix for ``coords``.

    Entry [a][b] equals ``euc_2D(coords[a], coords[b], PRECISION)`` and the
    diagonal is 0. Each unordered pair is computed once and mirrored.
    """
    size = len(coords)
    matrix = [[0] * size for _ in range(size)]
    for a, origin in enumerate(coords):
        for b in range(a + 1, size):
            matrix[a][b] = matrix[b][a] = euc_2D(origin, coords[b], PRECISION)
    return matrix


def parse_meta(lines, meta):
    """Consume the VEHICLE section of a benchmark file and fill ``meta``.

    Lines are popped from the front of ``lines`` (the list is mutated) until
    a customer-section header ("CUSTOMER" or "CUST NO.") is reached. That
    header is pushed back, stripped, so the caller can dispatch on it. Blank
    lines and lines containing "NUMBER" (the column header) are skipped.
    Every other line must have at least two columns; they are parsed as ints
    into ``meta["VEHICLES"]`` and ``meta["CAPACITY"]``.

    Advances the module-level ``line_no`` counter. Exits the process with
    status 2 on a line with too few columns.
    """
    global line_no
    while lines:
        line = lines.pop(0).strip()
        line_no += 1
        if not line:
            continue
        if "CUSTOMER" in line or "CUST NO." in line:
            # Hand the header back to the caller and undo the count for it.
            lines.insert(0, line)
            line_no -= 1
            break
        if "NUMBER" in line:
            continue
        x = line.split()
        if len(x) < 2:
            print("Cannot understand line " + str(line_no) + ": too few columns.")
            # sys.exit rather than the site-provided exit(), which is missing
            # under `python -S` and in embedded interpreters.
            sys.exit(2)
        meta["VEHICLES"] = int(x[0])
        meta["CAPACITY"] = int(x[1])


def parse_jobs(lines, jobs, coords):
    """Consume the customer rows, appending one job and one coordinate each.

    Lines are popped from the front of ``lines`` (the list is mutated). Blank
    lines and header lines containing "CUST " are skipped. A row whose first
    column contains "999" while fewer than 999 jobs have been read is treated
    as an end-of-data terminator. Every other row needs at least 7 columns:
    id, x, y, demand, ready time, due date, service time.

    For each row, ``[x, y]`` is appended to ``coords`` and a job dict is
    appended to ``jobs``. ``location_index`` counts from 0 within this call.
    Time-window bounds and service time are truncated to ints and multiplied
    by ``CUSTOM_PRECISION``.

    Advances the module-level ``line_no`` counter. Exits the process with
    status 2 on a row with too few columns.
    """
    global line_no
    location_index = 0
    while lines:
        line = lines.pop(0).strip()
        line_no += 1
        if not line or "CUST " in line:
            continue
        x = line.split()
        # Some files end the customer list with a '999' entry and others
        # don't. Check it before the column count, so that a short terminator
        # line is not reported as malformed.
        if "999" in x[0] and len(jobs) < 999:
            break
        if len(x) < 7:
            print("Cannot understand line " + str(line_no) + ": too few columns.")
            sys.exit(2)
        x_coord, y_coord = float(x[1]), float(x[2])
        coords.append([x_coord, y_coord])
        jobs.append(
            {
                "id": int(x[0]),
                "location": [x_coord, y_coord],
                "location_index": location_index,
                "delivery": [int(float(x[3]))],
                "time_windows": [
                    [
                        CUSTOM_PRECISION * int(float(x[4])),
                        CUSTOM_PRECISION * int(float(x[5])),
                    ]
                ],
                # Parsed via float like the other numeric fields, so decimal
                # service times such as "90.00" do not raise ValueError.
                "service": CUSTOM_PRECISION * int(float(x[6])),
            }
        )
        location_index += 1


def parse_vrptw(input_file):
    """Parse one Solomon/Homberger-style VRPTW (or TSPTW) benchmark file.

    Returns a dict with:
      * "jobs": one job per customer row except the first one, which is
        removed (in Solomon-format files the first row is presumably the
        depot; it keeps location index 0 in the matrix);
      * "matrices": {"car": {"durations": matrix}}, the all-pairs distance
        matrix over every row (first row included), scaled by
        CUSTOM_PRECISION.

    The VEHICLE section is parsed into a local ``meta`` dict for format
    validation, but is not part of the return value. Exits the process with
    status 2 when the file contains no customer rows.
    """
    global line_no
    # Error messages report per-file line numbers. This function is called
    # once per file in a folder, so the counter must restart here.
    line_no = 0

    with open(input_file, "r") as f:
        lines = f.readlines()

    meta = {}
    # The first non-blank line is the instance name, unless the file starts
    # directly with a "#NUM" customer header.
    while lines:
        line = lines.pop(0).strip()
        line_no += 1
        if line:
            if "#NUM" in line:
                # Push the header back for the dispatch loop below and undo
                # its line count, as parse_meta does.
                lines.insert(0, line)
                line_no -= 1
                meta["NAME"] = input_file
            else:
                meta["NAME"] = line
            break

    coords = []
    jobs = []

    while lines:
        line = lines.pop(0)
        line_no += 1
        if "VEHICLE" in line:
            parse_meta(lines, meta)
        elif "CUSTOMER" in line or "CUST " in line or "#NUM" in line:
            parse_jobs(lines, jobs, coords)

    if not jobs:
        print("No customer data found in " + str(input_file) + ".")
        sys.exit(2)

    matrix = get_matrix(coords, CUSTOM_PRECISION)

    # The first row stays in the matrix (index 0) but is not emitted as a job.
    jobs.pop(0)

    return {
        "jobs": jobs,
        "matrices": {"car": {"durations": matrix}},
    }


def _build_problem_definition(json_input):
    """Wrap one parsed benchmark into a "Problem Definitions" entry.

    ``json_input`` is a dict shaped like ``parse_vrptw``'s result. It must
    contain "matrices" and at least one of "jobs" / "shipments". The customer
    lists go under "Customers" (jobs first), the matrices under "Matrices".

    Raises:
        ValueError: if ``json_input`` has neither "jobs" nor "shipments".
    """
    customers = {
        key: json_input[key] for key in ("jobs", "shipments") if key in json_input
    }
    if not customers:
        raise ValueError("Parsed problem has neither 'jobs' nor 'shipments'.")
    return {"Customers": customers, "Matrices": json_input["matrices"]}


def _scale_time_windows(tasks, scaling_factor):
    """Multiply, in place, every bound of every time window of each task.

    Tasks without a "time_windows" entry are left untouched.
    """
    for task in tasks:
        for window in task.get("time_windows", []):
            window[:] = [bound * scaling_factor for bound in window]


def _scale_matrices(matrices, scaling_factor):
    """Multiply, in place, the "durations" matrix of every profile.

    Scaled values are rounded with ``round``. Profiles without a "durations"
    matrix are left untouched.
    """
    for profile in matrices.values():
        durations = profile.get("durations")
        if durations is None:
            continue
        for row in durations:
            row[:] = [round(value * scaling_factor) for value in row]


def _scale_problem_definitions(problem_definitions, scaling_factor):
    """Scale, in place, each problem's time windows and duration matrices.

    Service times are left unscaled.
    """
    for problem in problem_definitions.values():
        customers = problem["Customers"]
        _scale_time_windows(customers.get("jobs", []), scaling_factor)
        _scale_time_windows(customers.get("shipments", []), scaling_factor)
        _scale_matrices(problem["Matrices"], scaling_factor)


def main(argv=None):
    """Convert a folder of benchmark files into one combined JSON file.

    ``argv`` defaults to ``sys.argv`` and holds: the benchmark folder, the
    fleet-settings JSON file, the output file name (".json" is appended when
    absent) and a positive integer scaling factor.

    Raises:
        ValueError: if the scaling factor is not a positive integer.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 5:
        print(
            "Usage: "
            + (argv[0] if argv else "script")
            + " <benchmark folder> <fleet settings json> <output file>"
            + " <scaling factor>"
        )
        sys.exit(2)

    input_file_folder_path = argv[1]
    vehicle_fleet_settings = argv[2]
    output_file_name = argv[3]
    scaling_factor = int(argv[4])
    if scaling_factor <= 0:
        raise ValueError("4th argument is not positive integer.")

    # normpath drops a trailing separator, so "solomon/" still yields "solomon".
    name_of_folder = os.path.basename(os.path.normpath(input_file_folder_path))

    with open(vehicle_fleet_settings, "r") as json_file:
        json_data = json.load(json_file)

    problem_definitions = {}
    # Sorted for a reproducible output order; os.listdir order is arbitrary.
    for file_name in sorted(os.listdir(input_file_folder_path)):
        path = os.path.join(input_file_folder_path, file_name)
        if not os.path.isfile(path):
            continue
        json_input = parse_vrptw(path)
        name_of_problem = file_name.replace(".txt", "")
        problem_definitions[name_of_folder + "_" + name_of_problem] = (
            _build_problem_definition(json_input)
        )
    json_data["Problem Definitions"] = problem_definitions

    _scale_problem_definitions(problem_definitions, scaling_factor)

    if ".json" not in output_file_name:
        output_file_name += ".json"
    with open(output_file_name, "w") as out:
        json.dump(json_data, out)


if __name__ == "__main__":
    main()