/*
    BFilter - a smart ad-filtering web proxy
    Copyright (C) 2002-2007  Joseph Artsimovich <joseph_a@mail.ru>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include "pch.h"

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "HeuristicScore.h"
#include "URI.h"
#include "ArraySize.h"
#include "StringUtils.h"
#include "InsensitiveEqual.h"
#include "BString.h"
#include "BStringPOD.h"
#include "SplittableBuffer.h"
#include "SBOutStream.h"
#include "GlobalState.h"
#include "CombinedUrlPatterns.h"
#include <stddef.h>
#include <string>
#include <cctype>
#include <cstdlib>
#include <algorithm>

using namespace std;

/**
 * One entry of the common-size table (m_sCommonSizes).
 *
 * NOTE: field order matters - m_sCommonSizes and findSizeRecord() build
 * records with brace (aggregate) initialisation in this exact order.
 */
struct HeuristicScore::SizeRecord
{
	int width;  // width in pixels
	int height; // height in pixels
	bool is_ad; // what isCommonAdSize() reports for this size
	int score;  // what getSizeScore() returns for this size
};

/**
 * "Less than" for SizeRecords: orders by width, then by height, treating
 * dimensions that differ by at most MAX_DEVIATION pixels as equal.
 *
 * NOTE: equivalence-with-tolerance is not transitive, so this is not a
 * strict weak ordering over arbitrary sizes.  It is only used for
 * std::lower_bound over m_sCommonSizes, whose entries differ by more than
 * 2 * MAX_DEVIATION in the dimension that distinguishes them, which keeps
 * lookups well-defined.  Keep that property when editing the table.
 *
 * The call operator is const so the comparator can also be invoked through
 * a const reference (e.g. as a container's key_compare).
 */
struct HeuristicScore::SizeRecordComparator
{
	enum { MAX_DEVIATION = 2 };
	
	SizeRecordComparator() {}
	
	bool operator()(SizeRecord const& lhs, SizeRecord const& rhs) const {
		if (abs(lhs.width - rhs.width) > MAX_DEVIATION) {
			return lhs.width < rhs.width;
		} else if (abs(lhs.height - rhs.height) > MAX_DEVIATION) {
			return lhs.height < rhs.height;
		}
		// Both dimensions are within tolerance: treat as equivalent.
		return false;
	}
};


/**
 * Table of well-known image/banner sizes: { width, height, is_ad, score }.
 *
 * MUST stay sorted by width, then height: findSizeRecord() searches it with
 * std::lower_bound and SizeRecordComparator.  Entries should also stay more
 * than 2 * SizeRecordComparator::MAX_DEVIATION pixels apart in the dimension
 * that distinguishes them, since the comparator matches within a tolerance.
 */
HeuristicScore::SizeRecord const HeuristicScore::m_sCommonSizes[] = {
	// sorted by width, height
	{ 88,  31,  true, 80  },
	{ 100, 100, false, 30 },
	{ 120, 60,  true, 100 },
	{ 120, 90,  false, 30 },
	{ 120, 120, false, 30 },
	{ 120, 240, true, 120 },
	{ 120, 300, true, 120 },
	{ 120, 600, true, 120 },
	{ 125, 125, false, 30 },
	{ 160, 600, true, 120 },
	{ 180, 150, true, 100 },
	{ 200, 600, true, 80  },
	{ 234, 60,  true, 120 },
	{ 234, 120, true, 120 },
	{ 240, 400, true, 120 },
	{ 250, 250, false, 30 },
	{ 300, 250, false, 30 },
	{ 336, 280, true, 120 },
	{ 468, 60,  true, 120 },
	{ 468, 80,  true, 120 },
	{ 728, 90,  true, 120 }
};


/**
 * Maps the numeric score onto a coarse classification.
 *
 * Thresholds: score >= 100 -> AD, [60, 100) -> PROBABLY_AD,
 * [20, 60) -> PROBABLY_NOT_AD, below 20 -> NOT_AD.
 */
HeuristicScore::Status
HeuristicScore::getStatus() const
{
	int const score = getNumericScore();
	
	if (score >= 100) {
		return AD;
	}
	if (score >= 60) {
		return PROBABLY_AD;
	}
	if (score >= 20) {
		return PROBABLY_NOT_AD;
	}
	return NOT_AD;
}

/**
 * Computes heuristic flags describing a URL.
 *
 * @param url  the URL to inspect.
 * @param type the kind of resource the URL is expected to reference; used to
 *             decide whether its file extension is unexpected.
 * @return a bitwise combination of:
 *  - HAS_EMBEDDED_URL:     "http://" occurs (case-insensitively) in the
 *                          decoded path or query;
 *  - UNEXPECTED_EXTENSION: extensionMatchesType() rejects the extension of
 *                          the last path component;
 *  - GIF_EXTENSION / JPEG_EXTENSION / HTML_EXTENSION: the extension is
 *                          gif, jpg/jpeg, or html/htm (case-insensitively);
 *  - HAS_QUERY:            the URL has a query part;
 *  - PROBABLY_HAS_QUERY:   no query, but some non-final path component has a
 *                          3-character extension (e.g. /dir/script.php/data);
 *  - FRONT_PAGE:           absolute URL with an empty raw path and no query;
 *  - CGI_BIN:              the decoded path, ignoring leading slashes, starts
 *                          with "cgi-bin".
 */
HeuristicScore::UrlFlags
HeuristicScore::getUrlStatus(URI const& url, ExpectedType const type)
{
	UrlFlags flags = (UrlFlags)0;
	
	InsensitiveEqual const ieq;
	BString const path(url.getDecodedPath());
	BString const query(url.getDecodedQuery());
	BString const last_component(extractLastPathComponent(path));
	BString const extension(extractExtension(last_component));
	
	{
		// Search path and query together for another absolute URL.
		SBOutStream path_plus_query(path.size() + 1 + query.size());
		path_plus_query << path << '?' << query;
		if (!path_plus_query.data().ciFind(BString("http://")).isAtRightBorder()) {
			flags |= HAS_EMBEDDED_URL;
		}
	}
	
	if (!extensionMatchesType(extension, type)) {
		flags |= UNEXPECTED_EXTENSION;
	}
	if (ieq(extension, BString("gif"))) {
		flags |= GIF_EXTENSION;
	} else if (ieq(extension, BString("jpg"))
	           || ieq(extension, BString("jpeg"))) {
		flags |= JPEG_EXTENSION;
	} else if (ieq(extension, BString("html"))
	           || ieq(extension, BString("htm"))) {
		// ".htm" is the short form of ".html" (cf. extensionMatchesType()).
		flags |= HTML_EXTENSION;
	}
	
	if (url.hasQuery()) {
		flags |= HAS_QUERY;
	} else {
		// detect a path like this: /dir/script.php/more/data
		BString p(path);
		while (!p.empty()) {
			chopLastPathComponent(p);
			BString const component(extractLastPathComponent(p));
			BString const ext(extractExtension(component));
			if (ext.size() == 3) {
				flags |= PROBABLY_HAS_QUERY;
				break;
			}
		}
	}
	
	if (url.isAbsolute() && url.getRawPath().empty() && !url.hasQuery()) {
		flags |= FRONT_PAGE;
	}
	
	// A path following a host begins with '/', so a plain prefix test would
	// never match "/cgi-bin/...".  Skip leading slashes before comparing.
	char const* path_begin = path.begin();
	char const* const path_end = path.end();
	while (path_begin != path_end && *path_begin == '/') {
		++path_begin;
	}
	BString const cgi_bin("cgi-bin");
	if (StringUtils::startsWith(path_begin, path_end,
	                            cgi_bin.begin(), cgi_bin.end())) {
		flags |= CGI_BIN;
	}
	
	return flags;
}

/**
 * Classifies how @p url relates to @p base.
 *
 * @return URLS_SAME_HOST if @p url is relative;
 *         URLS_RELATED if @p url is absolute but its scheme is neither
 *         http nor https (e.g. javascript: or about:);
 *         otherwise getDomainRelationship() applied to the two hosts.
 */
HeuristicScore::UrlRelationship
HeuristicScore::getUrlRelationship(URI const& url, URI const& base)
{
	if (!url.isAbsolute()) {
		return URLS_SAME_HOST;
	}
	
	InsensitiveEqual const ieq;
	bool const is_web_scheme =
		ieq(url.getScheme(), BString("http")) ||
		ieq(url.getScheme(), BString("https"));
	if (!is_web_scheme) {
		// this will catch javascript: and about: urls
		return URLS_RELATED;
	}
	
	BString const& url_host = url.getHost();
	BString const& base_host = base.getHost();
	return getDomainRelationship(
		url_host.begin(), url_host.end(),
		base_host.begin(), base_host.end()
	);
}

bool
HeuristicScore::isCommonAdSize(int width, int height)
{
	SizeRecord const* rec = findSizeRecord(width, height);
	return rec && rec->is_ad;
}

int
HeuristicScore::getSizeScore(int width, int height)
{
	if ((width != -1 && width < 30) || (height != -1 && height < 15)) {
		// too small
		return -100;
	}
	
	SizeRecord const* rec = findSizeRecord(width, height);
	return rec ? rec->score : 0;
}

/**
 * Score modifier derived from the global URL patterns: the hint that
 * urlPatterns().getHintFor() reports for @p url, scaled by 10.
 */
int
HeuristicScore::getHintModifier(URI const& url)
{
	enum { HINT_WEIGHT = 10 };
	return HINT_WEIGHT * GlobalState::ReadAccessor()->urlPatterns().getHintFor(url);
}

/**
 * Classifies the relationship between two host names given as character
 * ranges [d1_begin, d1_end) and [d2_begin, d2_end).
 *
 * - A leading "www." is ignored on both sides (chopLeadingWWW()).
 * - If the remaining names are equal ignoring case: URLS_SAME_HOST.
 * - Otherwise the top-level domain is chopped from both names
 *   (chopTopLevelDomain()) and they are compared from the right.  If they
 *   share at least one complete trailing label (e.g. tomshardware.com vs
 *   tomshardware.de, or ads.example.com vs example.com) the result is
 *   URLS_RELATED; otherwise (e.g. a.com vs b.com, or myexample.com vs
 *   example.com) it is URLS_UNRELATED.
 */
HeuristicScore::UrlRelationship
HeuristicScore::getDomainRelationship(
	char const* d1_begin, char const* d1_end,
	char const* d2_begin, char const* d2_end)
{
	chopLeadingWWW(d1_begin, d1_end);
	chopLeadingWWW(d2_begin, d2_end);
	
	if (StringUtils::ciEqual(d1_begin, d1_end, d2_begin, d2_end)) {
		return URLS_SAME_HOST;
	}
	
	// tomshardware.com vs tomshardware.de should return URLS_RELATED
	// Note that we consider a domain like .co.uk to be a toplevel one. 
	chopTopLevelDomain(d1_begin, d1_end);
	chopTopLevelDomain(d2_begin, d2_end);
	
	char const* p1 = d1_end;
	char const* p2 = d2_end;
	// Position in d1 just past the leftmost dot matched so far, i.e. the
	// start of the shared run of complete trailing labels.  Staying at
	// d1_end means no complete label is shared yet.
	char const* p1_dot = d1_end;
	
	// walk from right to left while the characters are the same
	// (case-insensitive; the unsigned char cast keeps tolower() defined
	// for negative char values)
	for (; p1 != d1_begin && p2 != d2_begin &&
	     tolower(static_cast<unsigned char>(p1[-1])) ==
	     tolower(static_cast<unsigned char>(p2[-1])); --p1, --p2) {
		if (p1[-1] == '.') {
			p1_dot = p1;
		}	
	}
	
	if ((p1 != d1_begin && p1[-1] != '.') || (p2 != d2_begin && p2[-1] != '.')) {
		// at least one of the positions is in the middle of a subdomain
		// (a partially matched label doesn't count), so fall back to the
		// last whole-label boundary.
		p1 = p1_dot;
	}
	
	// Nothing shared on d1's side means the hosts are unrelated.
	return (p1 == d1_end ? URLS_UNRELATED : URLS_RELATED);
}

/**
 * If [begin, end) starts with "www." (case-insensitively), advances
 * @p begin past that prefix; otherwise leaves it unchanged.
 */
void
HeuristicScore::chopLeadingWWW(char const*& begin, char const* end)
{
	static char const prefix[] = {'w','w','w','.'};
	size_t const prefix_len = ARRAY_SIZE(prefix);
	
	if (!StringUtils::ciStartsWith(begin, end, prefix, prefix + prefix_len)) {
		return;
	}
	begin += prefix_len;
}

void
HeuristicScore::chopTopLevelDomain(char const* begin, char const*& end)
{
	char const* p1 = end;
	for (; p1 != begin && p1[-1] != '.'; --p1) {
		// search for a rightmost dot
	}
	if (p1 == begin) {
		// a dot wasn't found
		return;
	}
	
	if (end - p1 == 2) {
		char const* p2 = p1 - 1;
		for (; p2 != begin && p2[-1] != '.'; --p2) {
			// search for the next dot
		}
		if (end - p2 == 5 && p2 != begin) {
			// We consider domains like .co.uk to be toplevel
			end = p2 - 1; // p2 points *past* the dot
			return;
		} else if (end - p2 == 6 && p2 != begin) {
			// Some countries have domain system like .com.au, .org.au, etc
			BString const empty;
			BString l2_domain(empty, p2, p2 + 3);
			InsensitiveEqual ieq;
			if (ieq(l2_domain, BString("com")) ||
			    ieq(l2_domain, BString("org")) ||
			    ieq(l2_domain, BString("net"))) {
				// As for gov and mil, not handling them here
				// results in host1.gov.au and host2.gov.au
				// to be marked as URLS_RELATED, which seems fair.
				end = p2 - 1; // p2 points *past* the dot
				return;
			}
		}
	}
	
	end = p1 - 1; // p1 points *past* the dot
}

/**
 * Returns the last non-empty '/'-separated component of @p path, ignoring
 * trailing slashes ("/a/b//" -> "b").  The result shares @p path's storage.
 */
BString
HeuristicScore::extractLastPathComponent(BString const& path)
{
	char const* const first = path.begin();
	char const* last = path.end();
	
	// Drop the run of trailing slashes.
	while (last != first && *(last - 1) == '/') {
		--last;
	}
	
	// Walk back to the character right after the preceding slash.
	char const* start = last;
	while (start != first && *(start - 1) != '/') {
		--start;
	}
	
	return BString(path, start, last);
}

/**
 * Removes the last '/'-separated component of @p path together with any
 * trailing slashes after it, keeping the slash that precedes it
 * ("/a/b//" -> "/a/").
 */
void
HeuristicScore::chopLastPathComponent(BString& path)
{
	char const* const first = path.begin();
	char const* cut = path.end();
	
	while (cut != first && *(cut - 1) == '/') {
		--cut; // trailing slashes
	}
	while (cut != first && *(cut - 1) != '/') {
		--cut; // the component itself
	}
	
	path.trimBack(path.end() - cut);
}

/**
 * Returns the text after the last '.' in @p path_component, or an empty
 * string if it contains no dot.  The result shares the argument's storage.
 */
BString
HeuristicScore::extractExtension(BString const& path_component)
{
	char const* const first = path_component.begin();
	char const* const last = path_component.end();
	
	char const* after_dot = last;
	while (after_dot != first && after_dot[-1] != '.') {
		--after_dot;
	}
	
	if (after_dot == first) {
		// No dot found: empty range at the end.
		return BString(path_component, last, last);
	}
	return BString(path_component, after_dot, last);
}

/**
 * Tells whether a file extension (without the dot) is plausible for the
 * expected resource type.  Comparison is case-insensitive.
 *
 * @return false for an empty extension, even with EXPECT_ANY;
 *         true for any non-empty extension with EXPECT_ANY;
 *         otherwise whether the extension is listed for @p type
 *         (images: gif/png/jpg/jpeg, flash: swf, html: html/htm).
 */
bool
HeuristicScore::extensionMatchesType(
	BString const& extension, ExpectedType const type)
{
	if (extension.empty()) {
		return false;
	}
	
	static BStringPOD const image_exts[] = {
		{ "gif" }, { "png" }, { "jpg" }, { "jpeg" }
	};
	static BStringPOD const flash_exts[] = { { "swf" } };
	static BStringPOD const html_exts[] = { { "html" }, { "htm" } };
	
	BStringPOD const* known = 0;
	size_t num_known = 0;
	
	switch (type) {
	case EXPECT_ANY:
		return true;
	case EXPECT_IMAGE:
		known = image_exts;
		num_known = ARRAY_SIZE(image_exts);
		break;
	case EXPECT_FLASH:
		known = flash_exts;
		num_known = ARRAY_SIZE(flash_exts);
		break;
	case EXPECT_HTML:
		known = html_exts;
		num_known = ARRAY_SIZE(html_exts);
		break;
	}
	
	if (!known) {
		return false; // unrecognised ExpectedType value
	}
	return findMatchInsensitive(known, num_known, extension);
}

/**
 * Tells whether @p subject equals (case-insensitively) any of the
 * @p num_records strings starting at @p records.
 */
bool
HeuristicScore::findMatchInsensitive(
	BStringPOD const* records, size_t num_records, BString const& subject)
{
	InsensitiveEqual const ieq;
	BStringPOD const* const records_end = records + num_records;
	for (BStringPOD const* rec = records; rec != records_end; ++rec) {
		if (ieq(BString(*rec), subject)) {
			return true;
		}
	}
	return false;
}

/**
 * Looks up (width, height) in m_sCommonSizes, allowing the per-dimension
 * tolerance of SizeRecordComparator.
 *
 * @return the matching table entry, or a null pointer if there is none.
 */
HeuristicScore::SizeRecord const*
HeuristicScore::findSizeRecord(int width, int height)
{
	SizeRecord const* const table_begin = m_sCommonSizes;
	SizeRecord const* const table_end = table_begin + ARRAY_SIZE(m_sCommonSizes);
	
	SizeRecord const key = { width, height, false, 0 };
	SizeRecordComparator cmp;
	
	// First entry not ordered before the key; it matches unless the key
	// is ordered strictly before it.
	SizeRecord const* const candidate =
		std::lower_bound(table_begin, table_end, key, cmp);
	bool const found = (candidate != table_end && !cmp(key, *candidate));
	return found ? candidate : 0;
}
