/*
 * Experimental code being tried out
 *
 * Copyright (C) 2005  Enrico Zini <enrico@debian.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include <tagcoll/experiments.h>

#include <iostream>
#include <fstream>

#include <vector>
#include <set>

using namespace wibble::operators;

namespace std {

/**
 * Stream a set of tags as a comma-separated list, e.g. "a, b, c".
 *
 * Elements are written in the set's iteration order, each through its own
 * operator<<.  An empty set writes nothing.
 *
 * @param out   stream to write to
 * @param tags  the set to print
 * @return      out, to allow chaining
 */
template<typename TAG, typename _Traits>
basic_ostream<char, _Traits>& operator<<(basic_ostream<char, _Traits>& out, const std::set<TAG>& tags)
{
	typename std::set<TAG>::const_iterator cur = tags.begin();
	const typename std::set<TAG>::const_iterator last = tags.end();

	if (cur == last)
		return out;

	// The first element goes out bare, every following one is prefixed by ", "
	out << *cur;
	for (++cur; cur != last; ++cur)
		out << ", " << *cur;

	return out;
}

}

using namespace std;

namespace tagcoll {

/**
 * Connect ts1 and ts2 in distGraph when ts2 is a tagset of the collection and
 * its distance from ts1, as computed by scores, does not exceed maxMergeDist.
 *
 * When the arc is added, the search continues recursively with every tagset
 * obtained by removing one more tag from ts2, always measured against ts1.
 *
 * @param scores  provides distance() between two tagsets
 * @param ts1     the tagset the search started from
 * @param ts2     the candidate neighbour
 * @return        true if an arc between ts1 and ts2 was recorded
 */
template<typename ITEM, typename TAG>
bool Normalizer<ITEM,TAG>::addToGraph(const Scores<TAG>& scores, const std::set<TAG>& ts1, const std::set<TAG>& ts2)
{
	// The candidate must be a known tagset...
	if (this->tagsets.find(ts2) == this->tagsets.end())
		return false;

	// ...and must be within the maximum merge distance
	if (!(scores.distance(ts1, ts2) <= maxMergeDist))
		return false;

	// Record the arc in both directions
	distGraph[ts2].push_back(ts1);
	distGraph[ts1].push_back(ts2);

	// Try removing more tags to see if there is still something within
	// maxMergeDist of ts1
	typename std::set<TAG>::const_iterator tag = ts2.begin();
	while (tag != ts2.end())
	{
		addToGraph(scores, ts1, ts2 - *tag);
		++tag;
	}

	return true;
}


/**
 * Rebuild distGraph from scratch.
 *
 * For every tagset in the collection, each tagset obtained by removing one of
 * its tags is offered to addToGraph(), which records the arc (and explores
 * further removals) if the candidate qualifies.
 *
 * An earlier variant of this algorithm, for a fixed distance of 1, linked each
 * tagset directly to every known tagset obtained by removing a single tag,
 * without consulting any score.
 *
 * @param scores  distance provider handed through to addToGraph()
 */
template<typename ITEM, typename TAG>
void Normalizer<ITEM,TAG>::buildGraph(const Scores<TAG>& scores)
{
	distGraph.clear();

	typename tagsets_t::const_iterator entry = this->tagsets.begin();
	while (entry != this->tagsets.end())
	{
		const std::set<TAG>& full = entry->first;

		// Offer every "one tag less" variant of this tagset as a neighbour;
		// the boolean result of addToGraph is not needed here
		for (typename std::set<TAG>::const_iterator tag = full.begin();
				tag != full.end(); ++tag)
			addToGraph(scores, full, full - *tag);

		++entry;
	}
}

/**
 * Drop the tagset ts from the collection after it has been merged into merged.
 *
 * The entry for ts is erased from tagsets, every tag of ts that is not in
 * merged is passed to tags.del() together with the number of items ts had, and
 * ts is detached from distGraph: one arc pointing back to ts is removed from
 * each of its neighbours, then its own node is erased.
 *
 * @param ts      the tagset to remove
 * @param merged  the tagset that ts was merged into
 */
template<typename ITEM, typename TAG>
void Normalizer<ITEM,TAG>::removeAfterMerge(const std::set<TAG>& ts, const std::set<TAG>& merged)
{
	// The item count must be read before the entry goes away
	const int cardinality = this->tagsets[ts].size();
	this->tagsets.erase(ts);

	// Tags that disappear with the merge
	const std::set<TAG> dropped = ts - merged;
	for (typename std::set<TAG>::const_iterator tag = dropped.begin(); tag != dropped.end(); ++tag)
		this->tags.del(*tag, cardinality);

	// Erase ts from all arcs that point to it
	typename distgraph_t::iterator self = distGraph.find(ts);
	if (self != distGraph.end())
	{
		const vector< std::set<TAG> >& neighbours = self->second;
		for (size_t n = 0; n < neighbours.size(); ++n)
		{
			typename distgraph_t::iterator peer = distGraph.find(neighbours[n]);
			if (peer == distGraph.end())
				continue;

			// Locate the first arc back to ts and remove just that one
			vector< std::set<TAG> >& arcs = peer->second;
			typename vector< std::set<TAG> >::iterator arc = arcs.begin();
			while (arc != arcs.end() && !(*arc == ts))
				++arc;
			if (arc != arcs.end())
				arcs.erase(arc);
		}
	}

	// Erase ts from distgraph
	distGraph.erase(ts);
}

/**
 * Merge the tagsets ts1 and ts2 into their intersection.
 *
 * The items of ts1 and ts2 are moved to the tagset made of the tags the two
 * have in common; ts1 and ts2 are then removed via removeAfterMerge().
 * distGraph is not rebuilt afterwards.
 *
 * The merge is refused when the item counts of ts1, ts2 and of the current
 * intersection entry add up to more than max_threshold.  A refused merge
 * leaves the collection untouched.
 *
 * @param ts1  first tagset to merge
 * @param ts2  second tagset to merge
 * @return     true if the merge was performed, false if it was refused
 */
template<typename ITEM, typename TAG>
bool Normalizer<ITEM,TAG>::mergeTagsets(const std::set<TAG>& ts1, const std::set<TAG>& ts2)
{
	const std::set<TAG> merge = ts1 & ts2;

	// Look the entries up with find() instead of operator[]: operator[] would
	// insert empty entries for missing keys, leaving a spurious empty tagset
	// behind whenever the merge is refused.
	// The item sets are copied because removeAfterMerge() erases the entries.
	std::set<ITEM> items1;
	std::set<ITEM> items2;
	size_t mergedSize = 0;

	typename tagsets_t::const_iterator found = this->tagsets.find(ts1);
	if (found != this->tagsets.end())
		items1 = found->second;

	found = this->tagsets.find(ts2);
	if (found != this->tagsets.end())
		items2 = found->second;

	found = this->tagsets.find(merge);
	if (found != this->tagsets.end())
		mergedSize = found->second.size();

	// Don't merge if the result would be too big
	if (items1.size() + items2.size() + mergedSize > max_threshold)
		return false;

	// Merge: drop the two originals, then file their items under the
	// intersection (created here if it did not exist yet)
	removeAfterMerge(ts1, merge);
	removeAfterMerge(ts2, merge);

	this->tagsets[merge] |= items1;
	this->tagsets[merge] |= items2;

	return true;
}

/**
 * Repeatedly merge small tagsets into their smallest neighbour.
 *
 * Each run collects the tagsets holding fewer than merge_threshold items and,
 * for each of them, picks among its distGraph neighbours the existing tagset
 * with the fewest items (neighbours with 1000 items or more are never picked)
 * and tries to merge the two with mergeTagsets().  Runs are repeated until a
 * whole run performs no merge.
 *
 * distGraph is not built here: only the arcs already present in it are
 * consulted.  Progress is reported on cerr.
 */
template<typename ITEM, typename TAG>
void Normalizer<ITEM,TAG>::normalize()
{
	bool done = false;

	while (!done)
	{
		done = true;

		cerr << "Starting run." << endl;

		vector< std::set<TAG> > smallTagsets;

		// Collect the small tagsets
		for (typename tagsets_t::const_iterator i = this->tagsets.begin();
				i != this->tagsets.end(); ++i)
			if (i->second.size() < merge_threshold)
				smallTagsets.push_back(i->first);

		for (size_t i = 0; i < smallTagsets.size(); i++)
		{
			// Tagsets merged earlier in this run have already been erased from
			// distGraph and are skipped here
			typename distgraph_t::const_iterator arcs = distGraph.find(smallTagsets[i]);
			if (arcs == distGraph.end())
				continue;

			// See which of the nearest sets is the smallest
			std::set<TAG> smallest;
			size_t smallest_size = 1000;
			bool found = false;
			for (size_t j = 0; j < arcs->second.size(); j++)
			{
				// Use find() rather than operator[]: a neighbour that is no
				// longer in the collection must not be re-created as an empty
				// tagset (which would then always win as "smallest")
				typename tagsets_t::const_iterator candidate = this->tagsets.find(arcs->second[j]);
				if (candidate == this->tagsets.end())
					continue;

				size_t size = candidate->second.size();
				if (size < smallest_size)
				{
					smallest = arcs->second[j];
					smallest_size = size;
					found = true;
				}
			}

			// Without a usable neighbour there is nothing to merge with;
			// merging with the default-constructed (empty) tagset would strip
			// every tag from the items of smallTagsets[i]
			if (!found)
				continue;

			if (mergeTagsets(smallTagsets[i], smallest))
			{
				cerr << i << "/" << smallTagsets.size() << " Merged " << smallTagsets[i] << " and " << smallest << endl;
				done = false;
			}
		}
	}
}


/**
 * Return the integer handle associated with a node, allocating one on first
 * use.
 *
 * A node seen for the first time gets the current value of seq, which is then
 * incremented; later calls for the same node return the same handle.
 *
 * @param node  the tagset identifying the node
 * @return      the handle of the node
 */
template<typename ITEM, typename TAG>
int Graph<ITEM,TAG>::getHandle(const std::set<TAG>& node)
{
	typedef std::map< std::set<TAG>, int > handle_map;

	typename handle_map::iterator known = handles.find(node);
	if (known != handles.end())
		return known->second;

	// Only consume a sequence number when the node is really new
	return handles.insert(make_pair(node, seq++)).first->second;
}

/**
 * Build the label of a node: its tags, in set order, joined by the two
 * characters backslash and 'n'.
 *
 * @param node  the tagset to format
 * @return      the joined tags, or an empty string for an empty set
 */
template<typename TAG>
static std::string formatNode(const std::set<TAG>& node)
{
	std::string label;

	typename std::set<TAG>::const_iterator tag = node.begin();
	if (tag == node.end())
		return label;

	// First tag without separator, the others preceded by a literal "\n"
	label += *tag;
	for (++tag; tag != node.end(); ++tag)
		label += "\\n" + *tag;

	return label;
}

/**
 * Join the elements of a set, in set order, separating them with one space.
 *
 * @param node  the set of elements to format
 * @return      the joined elements, or an empty string for an empty set
 */
template<typename TAG>
static std::string formatItems(const std::set<TAG>& node)
{
	std::string joined;

	typename std::set<TAG>::const_iterator item = node.begin();
	if (item == node.end())
		return joined;

	// First element without separator, the others preceded by a space
	joined += *item;
	for (++item; item != node.end(); ++item)
		joined += " " + *item;

	return joined;
}

/**
 * Write the edges connecting node to the tagsets in selected, then recurse on
 * the tagsets just connected.
 *
 * Distances 1..maxdist are processed in increasing order.  For each distance,
 * every tagset of selected at exactly that set_distance() from node gets an
 * edge statement written to out, labelled with the tags added ("+tag") and
 * removed ("-tag") with respect to node.  The connected tagsets are removed
 * from selected, and a subgraph is built from each of them with the remaining
 * distance budget (maxdist - i) and one level less (maxlev - 1).
 *
 * @param out       stream receiving the edge statements
 * @param node      the tagset at the centre of this (sub)graph
 * @param selected  tagsets still to be connected; modified: the tagsets
 *                  connected here or in a recursive call are removed from it
 * @param maxdist   maximum distance from node to connect
 * @param maxlev    remaining recursion levels; 0 stops the recursion
 *                  (NOTE(review): a negative value never reaches 0, so it
 *                  appears to mean "no level limit" -- confirm with callers)
 */
template<typename ITEM, typename TAG>
void Graph<ITEM,TAG>::buildSubGraph(std::ostream& out, const std::set<TAG>& node, std::set< std::set<TAG> >& selected, int maxdist, int maxlev)
{
	if (maxlev == 0)
		return;

	// Lay out the nodes in order of distance
	for (int i = 1; i <= maxdist && !selected.empty(); i++)
	{
		// First connect all nodes that are still left at this distance
		vector< std::set<TAG> > connected;
		for (typename std::set< std::set<TAG> >::const_iterator j = selected.begin();
				j != selected.end(); j++)
			if (set_distance(*j, node) == i)
			{
				// Build the edge label: one "+tag" per tag that *j has and
				// node lacks, one "-tag" per tag that node has and *j lacks,
				// separated by a literal backslash-n sequence
				std::set<TAG> added = *j - node;
				std::set<TAG> removed = node - *j;
				string diff;
				for (typename std::set<TAG>::const_iterator n = added.begin();
						n != added.end(); n++)
					if (diff.empty())
						diff += "+" + *n;
					else
						diff += "\\n+" + *n;
				for (typename std::set<TAG>::const_iterator n = removed.begin();
						n != removed.end(); n++)
					if (diff.empty())
						diff += "-" + *n;
					else
						diff += "\\n-" + *n;
				
				// Connect to the main node
				out << "node" << getHandle(node) << "--node" << getHandle(*j) << "[" <<
					"label=\"" << diff << "\"," <<
					"fontsize=8," <<
					"weight=\"" << i << "\"" <<
					"];" << endl;
				// Remember it instead of erasing now: selected is being
				// iterated over
				connected.push_back(*j);
			}
		
		// Remove the nodes we just connected
		for (typename vector< std::set<TAG> >::const_iterator j = connected.begin();
				j != connected.end(); j++)
			selected.erase(*j);

		// Then build the subgraphs for the nodes we just connected, and remove
		// the nodes added in the subgraphs
		std::set< std::set<TAG> > removed;
		for (typename vector< std::set<TAG> >::const_iterator j = connected.begin();
				j != connected.end(); j++)
		{
			// Each subgraph works on its own copy of selected, so every
			// connected node starts from the same candidates; what each
			// recursive call consumed is accumulated in removed
			std::set< std::set<TAG> > subselected(selected);
			buildSubGraph(out, *j, subselected, maxdist - i, maxlev - 1);
			removed |= selected - subselected;
		}

		selected -= removed;
	}
}


/**
 * Write the body of the graph centred on node.
 *
 * Output, in order: the "root=" attribute naming node, the node statement for
 * node itself (drawn in red), one node statement for every tagset of the
 * collection whose set_distance() from node is between 1 and maxdist, and
 * finally the edges produced by buildSubGraph().  Font sizes shrink as the
 * distance from node grows.
 *
 * @param out      stream receiving the graph statements
 * @param node     the tagset at the centre of the graph
 * @param maxdist  maximum distance from node of the tagsets to include
 * @param maxlev   recursion limit handed to buildSubGraph()
 */
template<typename ITEM, typename TAG>
void Graph<ITEM,TAG>::buildGraph(std::ostream& out, const std::set<TAG>& node, int maxdist, int maxlev)
{
	const int rootHandle = getHandle(node);

	out << "root=node" << rootHandle << ";" << endl;

	// The central node
	out << "node" << rootHandle << "[";
	out << "label=\"" << formatNode(node) << "\",";
	out << "tooltip=\"" << formatItems(this->tagsets[node]) << "\",";
	out << "color=red,";
	out << "fontsize=" << (maxdist + 3) * 2;
	out << "];" << endl;

	// Choose and output the nodes that will go in the graph
	std::set< std::set<TAG> > selected;
	for (typename tagsets_t::const_iterator ts = this->tagsets.begin();
			ts != this->tagsets.end(); ++ts)
	{
		const int dist = set_distance(ts->first, node);
		// Skip node itself (distance 0) and everything too far away
		if (dist <= 0 || dist > maxdist)
			continue;

		selected.insert(ts->first);

		const int handle = getHandle(ts->first);
		out << "node" << handle << "[";
		out << "label=\"" << formatNode(ts->first) << "\",";
		out << "URL=\"node" << handle << ".html\",";
		out << "tooltip=\"" << formatItems(ts->second) << "\",";
		out << "fontsize=" << (maxdist - dist + 3) * 2;
		out << "];" << endl;
	}

	buildSubGraph(out, node, selected, maxdist, maxlev);
}

/**
 * Write one graph file per tagset of the collection.
 *
 * For every tagset a file named "<dir>/node<handle>.dot" is created (and
 * truncated if it exists), containing a "strict graph" whose body is produced
 * by buildGraph() centred on that tagset.
 *
 * A file that cannot be opened is reported on cerr and skipped; the remaining
 * tagsets are still processed.
 *
 * @param dir      directory where the files are written; it is not created
 *                 here
 * @param maxdist  maximum distance handed to buildGraph()
 */
template<typename ITEM, typename TAG>
void Graph<ITEM,TAG>::buildGraphs(const std::string& dir, int maxdist)
{
	for (typename tagsets_t::const_iterator i = this->tagsets.begin();
			i != this->tagsets.end(); ++i)
	{
		// Compose the file name from the node handle
		char buf[20];
		snprintf(buf, sizeof(buf), "%d", getHandle(i->first));
		const string fname = dir + "/node" + buf + ".dot";

		ofstream out(fname.c_str(), ios::out | ios::trunc);
		if (!out)
		{
			// Previously a failed open went unnoticed and the graph was
			// silently lost
			cerr << "Cannot open " << fname << " for writing" << endl;
			continue;
		}

		out << "strict graph {" << endl;
		out << "overlap=scale;" << endl;
		out << "splines=true;" << endl;
		buildGraph(out, i->first, maxdist);
		out << "}" << endl;
	}
}

}


// Explicit instantiations of the templates defined above for std::string
// items and tags, emitted unless INSTANTIATING_TEMPLATES is defined.
//
// COMPILE_TESTSUITE is temporarily undefined while this section is processed
// and restored afterwards.
#ifndef INSTANTIATING_TEMPLATES
#ifdef COMPILE_TESTSUITE
  #define OLD_COMPILE_TESTSUITE 
  #undef COMPILE_TESTSUITE
#endif

#include <string>

namespace tagcoll {
	    template class Normalizer<std::string, std::string>;
	    template class Scores<std::string>;
	    template class Graph<std::string, std::string>;
}

// An explicit instantiation must appear in a namespace enclosing its template:
// namespace tagcoll does not enclose std, so this one lives at global scope.
template class std::set< std::set<std::string> >;

#ifdef OLD_COMPILE_TESTSUITE
  #define COMPILE_TESTSUITE 
  #undef OLD_COMPILE_TESTSUITE
#endif
#endif


// Test suite for the experimental code, compiled only when COMPILE_TESTSUITE
// is defined.
#ifdef COMPILE_TESTSUITE

#include <tests/test-utils.h>

namespace wibble {
namespace tut {
using namespace tagcoll::tests;

// Data shared by the tests of this group; currently empty.
struct tagcoll_experiments_shar {
};
// Presumably declares the test group, and the `to` type used below, from the
// struct above -- see TESTGRP in tests/test-utils.h to confirm.
TESTGRP(tagcoll_experiments);

// Placeholder test: it performs no check at the moment.
template<> template<>
void to::test<1>()
{
	//ensure(false);
}

}
}

#endif

// vim:set ts=4 sw=4:
