#include "stdafx.h"

#include "dtw.h"
#include "calcul.h"
#include "print.h"
#include "help.h"
#include "draw.h"
#include "veSegment.h"

using namespace std;

#undef min
#undef max

///Main entry point of the dtw method: dispatches to the pairwise or the segmented variant.
///@param[in] input input data
///@param[in] info input data informations
///@param[in] params parameters (a set params.veWindow selects the segmented variant)
///@return dtw results
resultMethod dtw::dtwBase(inputMethod const &input, inputInfo const &info, parameter const &params)
{
	// Without a segmentation window the two series are aligned as a whole.
	if (!params.veWindow)
		return dtwPair(input, info, params);

	return dtwSegment(input, info, params);
}

///Calculates dtw for the pair of time series.
///Terminates the program with a failure exit code when the estimated size of the
///distance matrix exceeds the limit given by params.ram (in MB).
///@param[in] input input data
///@param[in] info input data informations
///@param[in] params parameters
///@return dtw results; when params.scoreReversed is set, score[params.scoreType] is replaced by 1 - score
resultMethod dtw::dtwPair(inputMethod const &input, inputInfo const &info, parameter const &params)
{
	const size_t sizeA = input.A.size();
	const size_t sizeB = input.B.size();

	// Every matrix cell is estimated at 8 bytes: cells / 131072 (= 1024 * 1024 / 8) gives MB.
	// The comparison is done without an (int) cast, which could overflow for very large inputs.
	const size_t estimatedMb = (sizeA * sizeB) / 131072;
	if (static_cast<double>(estimatedMb) > params.ram)
	{
		cout << "size A: " << sizeA << ", size B: " << sizeB << endl;
		cout << "DTW aborted. Input too large: " << (sizeA * sizeB * 8) / 1024 / 1024 << "MB" << endl;
		cout << "For overriding RAM limit use -ram [MBs] switch." << endl;
		exit(EXIT_FAILURE); // aborting is an error; do not report success to the caller
	}

	auto result = configure(input, params);

	if (params.outDraw.size() > 0)
		draw::plotPair(input, result, info, params);

	// Reverse the selected score exactly once (guarded against an out-of-range score type).
	if (params.scoreReversed)
	{
		const size_t type = static_cast<size_t>(params.scoreType);
		if (type < result.score.size())
			result.score[type] = 1 - result.score[type];
	}

	return result;
}

///Segments both input time series and calculates dtw for every pair of found sub time series.
///@param[in] input input data
///@param[in] info input data information
///@param[in] params parameters (params.veWindow / params.veSmooth control the segmentation)
///@return dtw result of the first pair of segments (the whole grid of results is only plotted);
///        a default-constructed result when either series yields no segment
resultMethod dtw::dtwSegment(inputMethod const &input, inputInfo const &info, parameter const &params)
{
	//segments
	auto setA = veSegment::getSegments(input.A, params.veWindow, params.veSmooth);
	auto setB = veSegment::getSegments(input.B, params.veWindow, params.veSmooth);

	vtr2<resultMethod> result(setA.size());

	for (size_t i = 0; i < setA.size(); i++)
	{
		result[i] = vtr<resultMethod>(setB.size());
		for (size_t j = 0; j < setB.size(); j++)
		{
			inputMethod subInput(setA[i], setB[j]);
			result[i][j] = configure(subInput, params);
		}
	}

	if (params.outDraw.size() > 0)
		draw::plotSegmets(input, result, info, params);

	// Same reversal as for a single pair: replace the selected score by 1 - score.
	// (Indexing the score vector by the segment index would run past its 5 entries.)
	if (params.scoreReversed)
	{
		const size_t type = static_cast<size_t>(params.scoreType);
		for (auto &row : result)
		{
			for (auto &cell : row)
			{
				if (type < cell.score.size())
					cell.score[type] = 1 - cell.score[type];
			}
		}
	}

	// Guard against empty segment sets instead of indexing an empty grid.
	if (result.empty() || result[0].empty())
		return resultMethod();

	return result[0][0];
}

///Chooses type of alignment: local (params.localAlignment) or global/subsequence.
///@param[in] input input data
///@param[in] params parameters (params.distance selects the distance type)
///@return dtw results
resultMethod dtw::configure(inputMethod const &input, parameter const &params)
{
	const distancet d(params.distance);

	return params.localAlignment
		? dtw::alignmentLocal(input, d, params)
		: dtw::alignment(input, d, params);
}

///Calculates the (global or subsequence) alignment of the input time series:
///builds the accumulated matrix, picks the path end and backtracks one warping path.
///@param[in] input input data
///@param[in] dist distance type (Euclid, Manhattan, CSI: chord, chroma)
///@param[in] params parameters
///@return dtw results with a single warping path and its scores (plus images when drawing)
resultMethod dtw::alignment(inputMethod const &input, distancet const &dist, parameter const &params)
{
	auto acc = dtw::matrix(input, dist, params);
	auto last = getEnds(acc, params);

	vtr<resultPath> warping(1);
	auto &primary = warping[0];
	primary = getWarping(acc, last, params);
	primary.scoreRaw = acc[last.row][last.col].value; // accumulated cost at the path end

	// Local minimums are only needed for drawing them.
	vtr<coord> minims = params.drawMin ? getMinimums(acc, params) : vtr<coord>();

	resultMethod result;
	const bool drawing = params.outDraw.size() > 0;
	if (drawing)
	{
		result.matrix_acc = draw::drawCombine(acc, warping, minims, params);
		result.matrix_noacc = draw::drawCombine(dtw::matrixNoaccumulation(input, dist, params), warping, minims, params);
	}

	if (params.debugInfo)
	{
		cout << endl << print::distanceMatrix(acc);
		cout << endl << primary.path;
	}

	result.score = getScore(input, warping);
	result.warpings = warping;

	return result;
}

///Calculates local alignments of the input time series (Kocyan): local minimums of the
///non-accumulated matrix are zeroed and used as path ends, the matrix is accumulated
///around them and the backtracked paths are filtered.
///@param[in] input input data
///@param[in] dist distance type (Euclid, Manhattan, CSI: chord, chroma)
///@param[in] params parameters
///@return dtw results with all accepted warping paths and their averaged scores
resultMethod dtw::alignmentLocal(inputMethod const &input, distancet const &dist, parameter const &params)
{
	auto acc = dtw::matrixNoaccumulation(input, dist, params);
	auto starts = dtw::getMinimums(acc, params);

	// Zeroed minimums are skipped by accumulateMod.
	for (size_t k = 0; k < starts.size(); k++)
		acc[starts[k].row][starts[k].col] = 0;

	dtw::accumulateMod(acc, starts, params);

	resultMethod result;
	result.warpings = getWarpings(acc, starts, params);
	filterWarpings(acc, result.warpings, params);

	const bool drawing = params.outDraw.size() > 0;
	if (drawing)
	{
		result.matrix_acc = draw::drawCombine(acc, result.warpings, starts, params);
		result.matrix_noacc = draw::drawCombine(dtw::matrixNoaccumulation(input, dist, params), result.warpings, starts, params);
	}

	if (params.debugInfo)
		cout << endl << print::distanceMatrix(acc);

	result.score = getScore(input, result.warpings);
	return result;
}

///Calculates all final scores (s1 - s5), each averaged over the given warping paths.
///@param[in] input input data
///@param[in] warpings warping paths from which final scores are calculated
///@return five dtw scores; with an empty warpings vector the sums are divided by zero
///        (NaN under IEEE floating point arithmetic)
vtr<double> dtw::getScore(inputMethod const &input, vtr<resultPath> const &warpings)
{
	vtr<double> score(5);
	for (auto const &w : warpings)
	{
		// Extent of the path in rows and columns (end and start are matrix coordinates).
		const auto rows = w.end.row - w.start.row;
		const auto cols = w.end.col - w.start.col;
		const auto maxScore = calcul::scoreDtwMax(input.A, input.B, w.start, w.end);

		score[0] += w.scoreRaw;
		score[1] += calcul::scoreDtwS2(w.scoreRaw, w.path.size());
		score[2] += calcul::scoreDtwS3(rows, cols, w.path.size());
		score[3] += calcul::scoreDtwS4(w.scoreRaw, maxScore);
		score[4] += calcul::scoreDtwS5(w.scoreRaw, maxScore, calcul::lenRatio(rows, cols));
	}

	for (auto &s : score)
		s /= warpings.size();

	return score;
}

///Builds the accumulated distance matrix restricted to the pass given by params.w.
///@param[in] input input data
///@param[in] distance distance type (Euclid, Manhattan, CSI: chord, chroma)
///@param[in] params parameters (relax, subsequence and pass width)
///@return (|A| + 1) x (|B| + 1) accumulated distance matrix
vtr2<node> dtw::matrix(inputMethod const &input, distancet const &distance, parameter const &params)
{
	int rows = (int)input.A.size();
	int cols = (int)input.B.size();

	vtr2<node> m(rows + 1);
	for (auto &line : m)
		line = vtr<node>(cols + 1);

	// Relaxed start: the first relax + 1 border cells (capped by the series length) cost nothing.
	// This includes [0][0].
	const int freeRows = std::min(rows, params.relax + 1);
	for (int r = 0; r < freeRows; r++)
		m[r][0].value = 0;

	const int freeCols = std::min(cols, params.relax + 1);
	for (int c = 0; c < freeCols; c++)
		m[0][c].value = 0;

	// Subsequence mode: the whole border along the longer series costs nothing.
	if (params.isSubsequence() && calcul::lenRatio(rows, cols) < params.subsequence)
	{
		if (rows < cols)
		{
			for (int c = 0; c <= cols; c++)
				m[0][c].value = 0;
		}
		else if (cols < rows)
		{
			for (int r = 0; r <= rows; r++)
				m[r][0].value = 0;
		}
	}

	const double coef = cols / static_cast<double>(rows);
	const int w = (int)(cols * params.w);
	for (int i = 1; i <= rows; i++) //row - y
	{
		const int from = calcul::dtwPassStart(i, w, coef);
		const int to = calcul::dtwPassEnd(i, w, coef, cols);
		for (int j = from; j < to; j++) //col - x
		{
			const double up = m[i - 1][j].value;
			const double left = m[i][j - 1].value;
			const double diag = m[i - 1][j - 1].value;

			const double cheapest = std::min({ up, left, diag });
			m[i][j].value = distance.getDistance(input, i, j) + cheapest;
		}
	}

	return m;
}

///Builds the non-accumulated distance matrix: every cell inside the pass holds only the
///distance between A[i - 1] and B[j - 1].
///@param[in] input input data
///@param[in] dist distance type (Euclid, Manhattan, CSI: chord, chroma)
///@param[in] params parameters (relax, subsequence and pass width)
///@return (|A| + 1) x (|B| + 1) distance matrix; border cells are zeroed according to
///        params.relax and the subsequence settings
vtr2<node> dtw::matrixNoaccumulation(inputMethod const &input, distancet const &dist, parameter const &params)
{
	int lenA = (int)input.A.size();
	int lenB = (int)input.B.size();

	vtr2<node> m(lenA + 1);
	for (auto &line : m)
		line = vtr<node>(lenB + 1);

	// Relaxed start: the first relax + 1 cells (capped by the series length) of the left column and top row.
	const int freeRows = std::min(lenA, params.relax + 1);
	for (int r = 0; r < freeRows; r++)
		m[r][0].value = 0;

	const int freeCols = std::min(lenB, params.relax + 1);
	for (int c = 0; c < freeCols; c++)
		m[0][c].value = 0;

	// Subsequence mode: the whole border along the longer series is zeroed.
	if (params.isSubsequence() && calcul::lenRatio(lenA, lenB) < params.subsequence)
	{
		if (lenA < lenB)
		{
			for (int c = 0; c <= lenB; c++)
				m[0][c].value = 0;
		}
		else if (lenB < lenA)
		{
			for (int r = 0; r <= lenA; r++)
				m[r][0].value = 0;
		}
	}

	const int w = (int)(lenB * params.w);
	const double coef = lenB / (double)lenA;
	for (int i = 1; i <= lenA; i++) //row = y
	{
		const auto &a = input.A[i - 1];
		const size_t first = calcul::dtwPassStart(i, w, coef);
		const size_t last = calcul::dtwPassEnd(i, w, coef, lenB);
		for (size_t j = first; j < last; j++) //col = x
		{
			const auto &b = input.B[j - 1];
			auto &cell = m[i][j].value;

			// dist.type: 1 = classic distance, 2 = CSI chroma, otherwise CSI chord
			if (dist.type == 1)
				cell = dist.dist.classic(a, b);
			else if (dist.type == 2)
				cell = dist.dist.csiChroma(a, b, 0.07);
			else
				cell = dist.dist.csiChord(a, b, input.A2[i - 1], input.B2[j - 1]);
		}
	}

	return m;
}

///Accumulates a non-accumulated distance matrix in place: every cell inside the pass
///gets the cheapest of its diagonal, upper and left predecessors added.
///@param[in,out] m distance matrix, accumulated in place
///@param[in] params parameters (pass width)
void dtw::accumulate(vtr2<node> &m, parameter const &params)
{
	int rows = (int)m.size() - 1;
	int cols = (int)m[0].size() - 1;

	const int w = (int)(cols * params.w);
	const double coef = cols / (double)rows;

	for (int i = 1; i <= rows; i++) //row = y
	{
		auto &cur = m[i];
		const auto &prev = m[i - 1];
		const size_t from = calcul::dtwPassStart(i, w, coef);
		const size_t to = calcul::dtwPassEnd(i, w, coef, cols);
		for (size_t j = from; j < to; j++) //col = x
			cur[j].value += std::min({ prev[j - 1].value, prev[j].value, cur[j - 1].value });
	}
}

///Accumulates a non-accumulated distance matrix in place, leaving the zeroed local
///minimums (Kocyan) untouched so that paths may start from them.
///@param[in,out] m non-accumulated distance matrix with the minimums set to zero
///@param[in] minims local minimums found in the non-accumulated distance matrix (Kocyan)
///@param[in] params parameters (pass width)
void dtw::accumulateMod(vtr2<node> &m, vtr<coord> const &minims, parameter const &params)
{
	int rows = (int)m.size() - 1;
	int cols = (int)m[0].size() - 1;

	const int w = (int)(cols * params.w);
	const double coef = cols / (double)rows;

	auto isStart = [&minims](size_t r, size_t c)
	{
		return std::find(minims.begin(), minims.end(), coord(r, c)) != minims.end();
	};

	for (size_t i = 1; i < (size_t)rows + 1; i++) //row = y
	{
		const size_t from = calcul::dtwPassStart(i, w, coef);
		const size_t to = calcul::dtwPassEnd(i, w, coef, cols);
		for (size_t j = from; j < to; j++) //col = x
		{
			// The cheap zero test comes first so the linear search only runs for zero cells.
			if (m[i][j].value == 0 && isStart(i, j))
				continue;

			m[i][j].value += std::min({ m[i - 1][j - 1].value, m[i - 1][j].value, m[i][j - 1].value });
		}
	}
}

///Finds local minimums (Kocyan) in a non-accumulated distance matrix.
///A cell is a minimum when it is strictly smaller than its upper, upper-right, left,
///right, lower-left and lower neighbours. Minimums closer than 3 rows and 3 columns
///to a later-found minimum are dropped.
///@param[in] m input distance matrix
///@param[in] params parameters (pass width)
///@return the remaining minimums in row-major order
vtr<coord> dtw::getMinimums(vtr2<node> const &m, parameter const &params)
{
	auto isMin = [&m = m](int i, int j)
	{
		if (i + 1 < static_cast<int>(m.size()) && j + 1 < static_cast<int>(m[i].size()) &&
			m[i - 1][j].value > m[i][j].value && m[i - 1][j + 1].value > m[i][j].value && m[i][j - 1].value > m[i][j].value &&
			m[i][j + 1].value > m[i][j].value && m[i + 1][j - 1].value > m[i][j].value && m[i + 1][j].value > m[i][j].value)
		{
			return true;
		}

		return false;
	};

	vtr<coord> minims;

	int lenA = (int)m.size() - 1;
	int lenB = (int)m[0].size() - 1;

	const double coef = lenB / (double)lenA;
	const int w = (int)(lenB * params.w);
	for (int i = 1; i < lenA + 1; i++) //row = y
	{
		const int start = calcul::dtwPassStart(i, w, coef);
		const int end = calcul::dtwPassEnd(i, w, coef, lenB);
		for (int j = start; j < end; j++) //col = x
		{
			if (isMin(i, j))
			{
				minims.emplace_back(coord(i, j));
				j++; // the right neighbour is larger than this cell, so it cannot be a minimum
			}
		}
	}

	// Absolute difference that also works for unsigned coordinate types.
	auto gap = [](auto a, auto b) { return a > b ? a - b : b - a; };

	// Keep a minimum unless a later one lies within 2 rows and 2 columns of it.
	// The column gap must be absolute: a later minimum may sit in a smaller column.
	vtr<coord> filtered;
	for (size_t i = 0; i < minims.size(); i++)
	{
		bool keep = true; // a minimum without later neighbours (e.g. the last one) is kept
		for (size_t j = i + 1; j < minims.size(); j++)
		{
			// minims is in row-major order: once the row gap reaches 3, no later minimum is near.
			if (gap(minims[j].row, minims[i].row) >= 3)
				break;

			if (gap(minims[j].col, minims[i].col) < 3)
			{
				keep = false;
				break;
			}
		}

		if (keep)
			filtered.emplace_back(minims[i]);
	}

	return filtered;
}

///Backtracks a warping path through the accumulated distance matrix.
///@param[in] m accumulated distance matrix
///@param[in] coords cell the backtrack starts from (bottom right end of the path)
///@param[in] params parameters (relax, flexible pass fd/fw, local/subsequence mode)
///@return warping path where:
///        - path holds one move per cell in start-to-end order ('M' diagonal, 'L' column step, 'U' row step);
///        - pathCoords holds the matching cells;
///        - start is the cell diagonally above-left of the first path cell;
///        - end is the given coords.
resultPath dtw::getWarping(vtr2<node> const &m, coord coords, parameter const &params)
{
	resultPath warping;
	warping.end = coords;

	// Used to break ties between the left and up moves.
	double lenA = coords.row;
	double lenB = coords.col;

	// Moves are collected end-to-start and reversed once at the end.
	// Prepending to a string on every step would make the backtrack quadratic in the path length.
	std::string moves;

	while (coords.row > 1 && coords.col > 1)
	{
		warping.pathCoords.emplace_back(coord(coords.row, coords.col));

		double d = 0, u = 0, l = 0;
		if (params.isPassFlexible())
		{
			// Predecessors outside the flexible pass around an earlier path cell (fd entries back) are excluded.
			int kd = (int)warping.pathCoords.size() - params.fd;
			coord past = kd >= 0 ? warping.pathCoords[kd] : warping.pathCoords[0];
			d = calcul::isFlexiblePass(coords.row - 1, coords.col - 1, past, params.fw) ? m[coords.row - 1][coords.col - 1].value : constant::MAX_double;
			u = calcul::isFlexiblePass(coords.row - 1, coords.col, past, params.fw) ? m[coords.row - 1][coords.col].value : constant::MAX_double;
			l = calcul::isFlexiblePass(coords.row, coords.col - 1, past, params.fw) ? m[coords.row][coords.col - 1].value : constant::MAX_double;
		}
		else
		{
			d = m[coords.row - 1][coords.col - 1].value;
			u = m[coords.row - 1][coords.col].value;
			l = m[coords.row][coords.col - 1].value;
		}

		if (std::min({ d, u, l }) == d)
		{
			moves += 'M';
			coords.row--;
			coords.col--;
		}
		else if (l < u)
		{
			moves += 'L';
			coords.col--;
		}
		else if (u < l)
		{
			moves += 'U';
			coords.row--;
		}
		else if (lenA / coords.row > lenB / coords.col)
		{
			// tie between left and up: decided by the ratios lenA / row and lenB / col
			moves += 'L';
			coords.col--;
		}
		else
		{
			moves += 'U';
			coords.row--;
		}
	}

	// Global alignment (and subsequence mode along the shorter series) continues straight to the relaxed origin.
	if ((!params.localAlignment && !params.isSubsequence()) || (params.isSubsequence() && m.size() < m[0].size()))
	{
		while (coords.row > params.relax + 1)
		{
			moves += 'U';
			warping.pathCoords.emplace_back(coord(coords.row, coords.col));
			coords.row--;
		}
	}

	if ((!params.localAlignment && !params.isSubsequence()) || (params.isSubsequence() && m[0].size() < m.size()))
	{
		while (coords.col > params.relax + 1)
		{
			moves += 'L';
			warping.pathCoords.emplace_back(coord(coords.row, coords.col));
			coords.col--;
		}
	}

	moves += 'M';
	warping.pathCoords.emplace_back(coord(coords.row, coords.col));
	coords.row--;
	coords.col--;

	// start is the upper left end, i.e. where the backtrack finished
	warping.start.row = coords.row;
	warping.start.col = coords.col;

	std::reverse(moves.begin(), moves.end());
	warping.path = moves;
	std::reverse(warping.pathCoords.begin(), warping.pathCoords.end());

	return warping;
}

///Backtracks one warping path from each given end cell (local minimum).
///@param[in] m accumulated distance matrix
///@param[in] minims local minimums found in the non-accumulated distance matrix (Kocyan); used as path ends
///@param[in] params parameters
///@return one warping path per minimum, in the order of minims, each with its end and
///        raw score (accumulated value at the end cell) set
vtr<resultPath> dtw::getWarpings(vtr2<node> const &m, vtr<coord> const &minims, parameter const &params)
{
	vtr<resultPath> paths;
	paths.reserve(minims.size());

	for (auto const &from : minims)
	{
		paths.emplace_back(getWarping(m, from, params));

		// Annotate the path that was just appended.
		auto &path = paths.back();
		path.end = from;
		path.scoreRaw = m[from.row][from.col].value;
	}

	return paths;
}

///Filters the warping paths found from local minimums (local dtw, Kocyan). Four passes:
/// 1. drop paths whose move string occurs inside a later path's move string,
/// 2. find for every remaining path the sub-ranges meeting the thresholds
///    (params.tresholdL: minimal length, params.tresholdE: bound on the largest increment
///    of the matrix value, params.tresholdA: bound on the mean increment),
/// 3. drop paths without an accepted range and cut the rest down to their first range,
/// 4. drop paths whose coordinates occur as a contiguous run inside a later path.
///@param[in] m input (accumulated) distance matrix
///@param[in,out] warpings warping paths; replaced by the filtered paths
///@param[in] params parameters
void dtw::filterWarpings(vtr2<node> const &m, vtr<resultPath> &warpings, parameter const &params)
{
	//filter sub paths
	auto filterSub = [&w = warpings] ()
	{
		vtr<resultPath> filter;
		bool accept = true;
		for (size_t i = 0; i < w.size(); i++)
		{
			accept = true;
			for (size_t j = i + 1; j < w.size(); j++)
			{
				auto found = w[j].path.find(w[i].path);
				if (w[i].path.size() <= w[j].path.size() &&  found != string::npos)
				{
					accept = false;
					break;
				}
			}
			if(accept)
				filter.emplace_back(w[i]);
		}
		return filter;
	};
	auto filteredWarpings = filterSub();

	//find sub paths which meet the threshold parameters
	auto filterThereshold = [&m, &fw = filteredWarpings, &p = params]()
	{
		vtr2<range> ranges(fw.size());
		for (int k = 0; k < (int)fw.size(); k++)
		{
			for (int i = 0; i < (int)fw[k].path.size(); i++)
			{
				double max = constant::MIN_double;
				range rangeTmp = range(-1, 0);
				bool accept = false;
				bool insert = true;
				double total = 0;

				if ((int)fw[k].path.size() >= p.tresholdL) {
					for (int j = i; j < (int)fw[k].path.size(); j++)
					{
						int len = j - i + 1;
						double current = m[fw[k].pathCoords[j].row][fw[k].pathCoords[j].col].value - total;
						total += std::max(current, 0.0);

						if (max < current)
							max = current;
						if (max <= p.tresholdE /*&& cTotal <= params.treshold_t*/ && total / len <= p.tresholdA && len >= p.tresholdL)
						{
							accept = true;

							// NOTE(review): i <= rangeTmp.end here, so this condition is always true;
							// j may have been intended - confirm before changing.
							if (!insert)
								if (i - rangeTmp.end < 2)
									rangeTmp.end = j;

							if (insert) {
								rangeTmp = range(i, j);
								insert = false;
							}
						}
					}
				}

				if (accept) //filter acceptable subsequences which are part of already accepted subs (sub of subs)
				{
					for (auto &r : ranges[k])
					{
						if (r.start <= rangeTmp.start && rangeTmp.end <= r.end)
						{
							accept = false;
						}
						else if (rangeTmp.start - r.end < 2)
						{
							r.end = rangeTmp.end; // adjacent/overlapping range: extend the accepted one
							accept = false;
							break;
						}
					}
				}

				if (accept)
					ranges[k].emplace_back(rangeTmp);
			}
		}
		return ranges;
	};
	auto filteredRanges = filterThereshold();

	//filters paths without ranges and cuts the others to their first range
	auto filterEmpty = [&r = filteredRanges, &fw = filteredWarpings]()
	{
		vtr2<range> ranges;
		vtr<resultPath> filter;

		for (size_t i = 0; i < r.size(); i++)
		{
			if (!r[i].empty())
			{
				ranges.emplace_back(r[i]);
				filter.emplace_back(fw[i]);
			}
		}

		vtr<resultPath> result;
		for (size_t i = 0; i < filter.size(); i++)
		{
			resultPath warp;
			warp = filter[i];
			warp.path = filter[i].path.substr(ranges[i][0].start, ranges[i][0].end - ranges[i][0].start + 1);
			warp.pathCoords = vtr<coord>(filter[i].pathCoords.begin() + ranges[i][0].start, filter[i].pathCoords.begin() + ranges[i][0].end + 1);
			warp.start = filter[i].pathCoords[ranges[i][0].start];
			warp.end = filter[i].pathCoords[ranges[i][0].end];
			result.emplace_back(warp);
		}

		return result;
	};
	filteredWarpings = filterEmpty();

	//filters paths whose coordinates are contained in any later path
	auto filterSame = [&fw = filteredWarpings]()
	{
		vtr<resultPath> filtered;

		for (size_t i = 0; i < fw.size(); i++)
		{
			bool accept = true;
			for (size_t j = i + 1; j < fw.size(); j++)
			{
				auto found = search(fw[j].pathCoords.begin(), fw[j].pathCoords.end(), fw[i].pathCoords.begin(), fw[i].pathCoords.end());
				// Stop only on a match; every later path has to be checked otherwise.
				if (found != fw[j].pathCoords.end())
				{
					accept = false;
					break;
				}
			}

			if (accept)
				filtered.emplace_back(fw[i]);
		}

		return filtered;
	};

	warpings = filterSame();
}

///Finds the end (bottom right) cell of the warping path: the cheapest cell of the last
///column or the last row within the relaxed (or, in subsequence mode, open) end region.
///Ties go to the cell scanned last (last column top-down first, then last row left-to-right).
///@param[in] m accumulated distance matrix
///@param[in] params parameters (relax, subsequence mode)
///@return coordinates of the warping path end
coord dtw::getEnds(vtr2<node> const &m, parameter const &params)
{
	int lastRow = (int)m.size() - 1;
	int lastCol = (int)m[0].size() - 1;

	int fromRow = std::max(lastRow - params.relax, 0);
	int fromCol = std::max(lastCol - params.relax, 0);

	// In subsequence mode the path may end anywhere along the longer series.
	if (params.isSubsequence())
	{
		if (lastRow < lastCol)
			fromCol = 0;
		else if (lastCol < lastRow)
			fromRow = 0;
	}

	double best = constant::MAX_double;
	coord found;

	// last column, scanning down the rows
	for (size_t r = fromRow; r < m.size(); r++)
	{
		if (m[r][lastCol].value > best)
			continue;

		best = m[r][lastCol].value;
		found.row = (int)r;
		found.col = lastCol;
	}

	// last row, scanning across the columns
	for (size_t c = fromCol; c < m[0].size(); c++)
	{
		if (m[lastRow][c].value > best)
			continue;

		best = m[lastRow][c].value;
		found.row = lastRow;
		found.col = (int)c;
	}

	return found;
}