/* 
   elmo - ELectronic Mail Operator

   Copyright (C) 2003, 2004 rzyjontko, University of Wroclaw

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; version 2.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software Foundation,
   Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.  

   ----------------------------------------------------------------------

   This file contains an implementation of a Bayesian mail filter, as
   described in Paul Graham's articles at
   http://www.paulgraham.com/spam.html
   and
   http://www.paulgraham.com/better.html

   This module was written for a course in Artificial Intelligence at the
   University of Wroclaw, Institute of Computer Science.

*/
/****************************************************************************
 *    IMPLEMENTATION HEADERS
 ****************************************************************************/

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>

#include "hash.h"
#include "bayes.h"
#include "mail.h"
#include "wrapbox.h"
#include "xmalloc.h"
#include "ask.h"
#include "error.h"
#include "token.h"
#include "rstring.h"
#include "debug.h"
#include "file.h"

/****************************************************************************
 *    IMPLEMENTATION PRIVATE DEFINITIONS / ENUMERATIONS / SIMPLE TYPEDEFS
 ****************************************************************************/

#define PROB_CERTAIN    10000
#define PROB_NEUTRAL     4000
#define PROB_LIMIT       9000
#define PROB_HALF        4900
#define PROB_MIN           10
#define PROB_MAX         9900
#define PROB_MAX_MAX     9999
#define PROB_MIN_OCCUR      5

#define WORD_MAX_LEN 60

#define MAX(a,b)   (((a) < (b)) ? (b) : (a))

#define HEAP_SIZE  16

#define PREAMBLE do { if (!initialized) bayes_init (); } while (0)

/****************************************************************************
 *    IMPLEMENTATION PRIVATE CLASS PROTOTYPES / EXTERNAL CLASS REFERENCES
 ****************************************************************************/
/****************************************************************************
 *    IMPLEMENTATION PRIVATE STRUCTURES / UTILITY CLASSES
 ****************************************************************************/
/****************************************************************************
 *    IMPLEMENTATION REQUIRED EXTERNAL REFERENCES (AVOID)
 ****************************************************************************/
/****************************************************************************
 *    IMPLEMENTATION PRIVATE DATA
 ****************************************************************************/

/**
 * Many people don't use anti-spam filters.  Let them save memory space:
 * the module is only set up on first use.
 * This variable is used in the PREAMBLE macro -- see above.
 */
static int initialized = 0;

/**
 * Hash tables storing good and bad words (mapped to their number of
 * occurrences), and a probability table which maps words to their
 * Bayesian probability (scaled by PROB_CERTAIN).
 */
static htable_t *good_table = NULL;
static htable_t *bad_table  = NULL;
static htable_t *prob_table = NULL;

/**
 * These variables are very important.  They are used to calculate the
 * a posteriori probability that a word is part of a spam message.
 * They hold the count of _scanned_ messages from each corpus.
 */
static int good_messages = 0;
static int bad_messages  = 0;

/**
 * This is probably the fastest way to obtain the 15 most interesting
 * words.  This is a static-size heap of HEAP_SIZE (16) slots.  Slot 0 is
 * a scratch slot where a new candidate is placed before heapify().
 * Slots 1..15 form a heap ordered by "interest" (distance from
 * PROB_HALF), whose top element (slot 1) is the _least_ interesting
 * word.  The idea is to put all the words into the heap one by one and
 * heapify after every such operation.  In the end, only the 15 most
 * interesting words remain in the heap.  Heapify runs in constant time
 * (because of the heap's constant size), so finding the 15 most
 * interesting words runs in linear time.
 */
static entry_t  *word_heap[HEAP_SIZE];

/**
 * This is used when dumping a table.  print_list, which is passed to
 * htable_iterator, writes to this file.
 */
static FILE *print_list_fp = NULL;


/**
 * This is used to indicate how to change the count of a word.  When
 * scanning a message collect_change = +1, so the collect functions
 * increase the word count.  When unscanning a message
 * collect_change = -1, so they decrease the word count.
 */
static int collect_change = 0;


/****************************************************************************
 *    INTERFACE DATA
 ****************************************************************************/
/****************************************************************************
 *    IMPLEMENTATION PRIVATE FUNCTION PROTOTYPES
 ****************************************************************************/

static void destroy_int (void *a);
static int  list_len (entry_t *l);
static void print_list (entry_t *l);
static void print_list_txt (FILE *fp, entry_t *l);
static void dump_table (FILE *fp, htable_t *table);
static void dump_tables_to_file (char *fname);
static void dump_table_txt (FILE *fp, htable_t *table);
static void load_table (FILE *fp, htable_t **table);
static void load_from_file (char *fname);

static int  heap_value (int i);
static void heap_empty (void);
static int  heap_contains (entry_t *entry);
static int  heap_internal_value (int i);
static int  heap_bottom (int i);
static int  heap_left_child (int i);
static int  heap_right_child (int i);
static void heap_exchange (int i, int j);
static void heapify (void);

static void scan_body (char *buf, void (*fun)(char *));
static int  leave_header (char *buf, int leave_rest);
static void scan_header (char *buf, void (*fun)(char *));

static void collect_insert_bad (char *word);
static void collect_insert_good (char *word);
static void collect_message (mail_t *mail, int is_spam, int change);

static int  prob_word (char *word);
static int  prob_heap (void);
static void prob_insert_word (entry_t *entry);
static void prob_create_table (void);

static void verify_word (char *word);
static void verify_message (mail_t *mail);
static int  verify_is_spam (mail_t *mail);

/****************************************************************************
 *    LOADING AND STORING DATA FUNCTIONS
 ****************************************************************************/

/**
 * Content destructor passed to htable_destroy.  Table contents are plain
 * integers stored in the pointer field, so there is nothing to release.
 */
static void
destroy_int (void *a)
{
        (void) a;
}


/**
 * Count the entries of a hash bucket chain linked through `next`.
 * Returns 0 for an empty (NULL) chain.
 */
static int
list_len (entry_t *l)
{
        int n;

        for (n = 0; l != NULL; l = l->next)
                n++;
        return n;
}



/**
 * Write a single table entry to print_list_fp in the binary format read
 * back by load_table: the key length (int), the key bytes including the
 * terminating NUL, then the occurrence count (int).  Used as the
 * htable_iterator callback in dump_table; does nothing when
 * print_list_fp is not set.
 *
 * The count is written as an int, not as the raw `void *` content field.
 * load_table reads sizeof (int) bytes, so writing sizeof (void *) bytes
 * made the database unreadable on 64-bit platforms.
 */
static void
print_list (entry_t *l)
{
        int len;
        int content;

        if (print_list_fp == NULL)
                return;

        len     = (int) strlen (l->key);
        content = (int) (intptr_t) l->content;

        fwrite (&len, 1, sizeof (len), print_list_fp);
        fwrite (l->key, 1, len + 1, print_list_fp);
        fwrite (&content, 1, sizeof (content), print_list_fp);
}



/**
 * Print every entry of a bucket chain to FP as "<keylen> <key> <count> ".
 *
 * strlen returns size_t, and passing it to "%d" is undefined behaviour.
 * The length is therefore cast to int; keys are bounded by the
 * tokenizer's word length limit.
 */
static void
print_list_txt (FILE *fp, entry_t *l)
{
        for (; l != NULL; l = l->next)
                fprintf (fp, "%d %s %d ", (int) strlen (l->key), l->key,
                         (int) (intptr_t) l->content);
}



/**
 * Write one hash table to FP in binary form.  The layout is the table's
 * exponent, then its element count, then one record per entry as
 * produced by print_list.  print_list_fp is pointed at FP only while
 * the table is being iterated.
 */
static void
dump_table (FILE *fp, htable_t *table)
{
        fwrite (&table->exponent, 1, sizeof (table->exponent), fp);
        fwrite (&table->count, 1, sizeof (table->count), fp);

        print_list_fp = fp;
        htable_iterator (table, print_list);
        print_list_fp = NULL;
}



/**
 * Save the message counters and both word tables to FNAME, in the
 * binary format understood by load_from_file.
 *
 * The file is opened in binary mode.  Failures to open, write or close
 * it are reported through error_.  A short write (e.g. a full disk)
 * would otherwise leave a truncated database behind without any notice.
 */
static void
dump_tables_to_file (char *fname)
{
        FILE *fp;

        fp = fopen (fname, "wb");
        if (fp == NULL){
                error_ (errno, "%s", fname);
                return;
        }

        fwrite (&good_messages, 1, sizeof (good_messages), fp);
        fwrite (&bad_messages, 1, sizeof (bad_messages), fp);

        dump_table (fp, good_table);
        dump_table (fp, bad_table);

        if (ferror (fp)){
                error_ (errno, "%s", fname);
                fclose (fp);
                return;
        }

        /* fclose flushes buffered data, so it can fail as well. */
        if (fclose (fp) != 0)
                error_ (errno, "%s", fname);
}



/**
 * Write a human-readable dump of TABLE to FP.  The output is the
 * message counters, the table exponent, then one line per bucket giving
 * the chain length followed by its entries (see print_list_txt).
 */
static void
dump_table_txt (FILE *fp, htable_t *table)
{
        int bucket;
        int buckets = 1 << table->exponent;

        fprintf (fp, "good_messages: %d\n", good_messages);
        fprintf (fp, "bad_messages:  %d\n", bad_messages);
        fprintf (fp, "%d\n", table->exponent);

        for (bucket = 0; bucket < buckets; bucket++){
                fprintf (fp, "%2d ", list_len (table->array[bucket]));
                print_list_txt (fp, table->array[bucket]);
                fputs ("\n", fp);
        }
}



/**
 * Read one hash table written by dump_table from FP into *TABLE.
 *
 * The file contents are treated as untrusted:
 *  - every fread is checked;
 *  - the exponent and count must be sane;
 *  - each key length is bounded before anything is copied into the key
 *    buffer (an unchecked length used to overflow it).
 *
 * If the table header cannot be read, *TABLE is left untouched.
 * Otherwise any table already in *TABLE is destroyed and replaced (so
 * reloading does not leak it).  Entries are read until the first short
 * or malformed record; those read so far are kept.
 */
static void
load_table (FILE *fp, htable_t **table)
{
        int  i;
        int  size;
        int  count;
        int  key_len;
        int  content;
        char key[WORD_MAX_LEN + 2];

        if (fread (&size, sizeof (size), 1, fp) != 1
            || fread (&count, sizeof (count), 1, fp) != 1)
                return;

        /* size is the table exponent (dump_table_txt computes
           1 << exponent), so values outside [0, 30] cannot be valid. */
        if (size < 0 || size > 30 || count < 0)
                return;

        if (*table)
                htable_destroy (*table, destroy_int);
        *table = htable_create (size);

        for (i = 0; i < count; i++){
                if (fread (&key_len, sizeof (key_len), 1, fp) != 1)
                        break;

                /* key holds at most WORD_MAX_LEN + 1 bytes plus the NUL. */
                if (key_len < 0 || key_len > WORD_MAX_LEN + 1)
                        break;
                if (fread (key, 1, (size_t) key_len + 1, fp)
                    != (size_t) key_len + 1)
                        break;
                key[key_len] = '\0';

                if (fread (&content, sizeof (content), 1, fp) != 1)
                        break;

                /* key is reused for every record, as before; the table
                   is expected to keep its own copy. */
                htable_insert (*table, key, (void *) (intptr_t) content);
        }
}



/**
 * Load the message counters and both word tables from FNAME, as written
 * by dump_tables_to_file.  Does nothing for a NULL or empty name.
 * Reports a file that cannot be opened through error_.
 *
 * The counters are only adopted if both were read completely and are
 * non-negative.  prob_word divides by them, and a truncated or empty
 * file previously fed uninitialized values into load_table.  The file
 * is opened in binary mode.
 */
static void
load_from_file (char *fname)
{
        FILE *fp;
        int   good;
        int   bad;

        if (fname == NULL || *fname == '\0')
                return;

        fp = fopen (fname, "rb");
        if (fp == NULL){
                error_ (errno, "%s", fname);
                return;
        }

        if (fread (&good, sizeof (good), 1, fp) != 1
            || fread (&bad, sizeof (bad), 1, fp) != 1
            || good < 0 || bad < 0){
                fclose (fp);
                return;
        }

        good_messages = good;
        bad_messages  = bad;

        load_table (fp, &good_table);
        load_table (fp, &bad_table);

        fclose (fp);
}

/****************************************************************************
 *    HEAP FUNCTIONS
 ****************************************************************************/

static int
heap_value (int i)
{
        if (word_heap[i] == NULL)
                return PROB_NEUTRAL;

        return (int) word_heap[i]->content;
}


/**
 * Clear every slot of word_heap, including the scratch slot 0.
 */
static void
heap_empty (void)
{
        int slot = HEAP_SIZE;

        while (slot-- > 0)
                word_heap[slot] = NULL;
}


/**
 * Return 1 if ENTRY occupies any slot of word_heap (scratch slot 0
 * included), 0 otherwise.
 */
static int
heap_contains (entry_t *entry)
{
        int slot;

        for (slot = 0; slot < HEAP_SIZE; slot++)
                if (word_heap[slot] == entry)
                        return 1;
        return 0;
}


/**
 * "Interest" of heap slot I: the absolute distance of its probability
 * from PROB_HALF.  An empty slot has interest 0.
 */
static int
heap_internal_value (int i)
{
        int distance;

        if (word_heap[i] == NULL)
                return 0;

        distance = (int) word_heap[i]->content - PROB_HALF;
        return (distance < 0) ? -distance : distance;
}



/**
 * A node is a leaf when its left child index (2 * I) falls outside the
 * heap array.
 */
static int
heap_bottom (int i)
{
        return 2 * i >= HEAP_SIZE;
}


/**
 * Index of the left child of node I in the 1-based heap layout.
 */
static int
heap_left_child (int i)
{
        return 2 * i;
}


/**
 * Index of the right child of node I in the 1-based heap layout.
 */
static int
heap_right_child (int i)
{
        return 2 * i + 1;
}


/**
 * Swap the contents of heap slots I and J.
 */
static void
heap_exchange (int i, int j)
{
        entry_t *saved = word_heap[j];

        word_heap[j] = word_heap[i];
        word_heap[i] = saved;
}


/**
 * Offer the candidate in scratch slot 0 to the heap in slots 1..15.
 *
 * If the candidate is not more interesting than the current top (the
 * least interesting word, slot 1), it is rejected and stays in slot 0.
 * Otherwise it swaps places with the top, which ends up in slot 0.  It
 * is then sifted down towards the less interesting of its children
 * until the heap order holds again.
 */
static void
heapify (void)
{
        int parent;
        int child;
        int sibling;

        if (heap_internal_value (0) <= heap_internal_value (1))
                return;

        heap_exchange (0, 1);

        for (parent = 1; !heap_bottom (parent); parent = child){
                child   = heap_left_child (parent);
                sibling = heap_right_child (parent);

                /* Sift toward the less interesting child. */
                if (heap_internal_value (sibling) < heap_internal_value (child))
                        child = sibling;

                if (heap_internal_value (parent) <= heap_internal_value (child))
                        break;

                heap_exchange (parent, child);
        }
}

/****************************************************************************
 *    SCANNING FUNCTIONS
 ****************************************************************************/

/**
 * Tokenize BUF with the Bayes word regexp and pass each word to FUN.
 * Returns silently if the tokenizer cannot be opened.
 */
static void
scan_body (char *buf, void (*fun)(char *))
{
        char *word;
        int   td;

        td = token_open (buf, WORD_MAX_LEN, TOKEN_BAYES_WORD_RE, 1);
        if (td == -1)
                return;

        while ((word = token_read_next (td)) != NULL)
                fun (word);

        token_close (td);
}


/**
 * Decide whether the header field named BUF should be skipped (returns
 * 1) or scanned (returns 0).
 *
 * - A name starting with a space or tab is a continuation; it inherits
 *   LEAVE_REST, the decision made for the preceding header.
 * - Of the names starting with 'R', only those containing "Received"
 *   are scanned (Resent-* and the like are skipped).
 * - Names starting with 'F', 'S' or 'T' (From, Sender, To, Subject, ...)
 *   are scanned.
 * - Everything else is skipped.
 */
static int
leave_header (char *buf, int leave_rest)
{
        const char first = *buf;

        if (first == ' ' || first == '\t')
                return leave_rest;

        if (first == 'R')
                return strstr (buf, "Received") == NULL;

        if (first == 'F' || first == 'S' || first == 'T')
                return 0;

        return 1;
}


/**
 * Split the header block BUF into lines and pass the words of every
 * header selected by leave_header to FUN.  BUF's lines are split into a
 * temporary rstring, and the copy is modified in place.
 *
 * Folded continuation lines (starting with a space or tab) belong to
 * the preceding header.  They are scanned in full when that header was
 * scanned and skipped otherwise.  Previously they only took part when
 * they happened to contain a ':' (e.g. a time in a Received line), and
 * then everything before that colon was lost.
 */
static void
scan_header (char *buf, void (*fun)(char *))
{
        rstring_t *lines;
        int        i;
        int        leave_rest = 0;
        char      *line;
        char      *seek;

        lines = rstring_split (buf, "\n");
        for (i = 0; i < lines->count; i++){
                line = lines->array[i];

                if (*line == ' ' || *line == '\t'){
                        if (! leave_rest)
                                scan_body (line, fun);
                        continue;
                }

                seek = strchr (line, ':');
                if (seek == NULL)
                        continue;

                /* Terminate the field name so leave_header sees only it. */
                *seek = '\0';
                if (leave_header (line, leave_rest)){
                        leave_rest = 1;
                        continue;
                }
                leave_rest = 0;
                scan_body (seek + 1, fun);
        }
        rstring_delete (lines);
}

/****************************************************************************
 *    COLLECTING DATA FUNCTIONS
 ****************************************************************************/

static void
collect_insert_bad (char *word)
{
        entry_t *entry;

        entry = htable_insert (bad_table, word, (void *) 0);
        entry->content = (void *) ((int) entry->content + collect_change);

        if ((int) entry->content < 0)
                entry->content = (void *) 0;
}



static void
collect_insert_good (char *word)
{
        entry_t *entry;
  
        entry = htable_insert (good_table, word, (void *) 0);
        entry->content = (void *) ((int) entry->content + collect_change);

        if ((int) entry->content < 0)
                entry->content = (void *) 0;
}


/**
 * Add (CHANGE = +1) or remove (CHANGE = -1) the words of MAIL's decoded
 * headers and body to/from the spam table (IS_SPAM) or the legitimate
 * table.
 *
 * Training changes the word counts, and the callers change the message
 * counters.  Any cached probability table is therefore dropped here;
 * verify_is_spam rebuilds it on demand.  Previously the cache kept
 * classifying with the old statistics until the program was restarted.
 */
static void
collect_message (mail_t *mail, int is_spam, int change)
{
        char  *buffer = NULL;
        int    len;
        str_t *str;
        void (*fun)(char *);

        collect_change = change;
        fun = is_spam ? collect_insert_bad : collect_insert_good;

        if (prob_table){
                htable_destroy (prob_table, destroy_int);
                prob_table = NULL;
        }

        if (wrapbox_mail_header (mail, &buffer))
                return;

        len    = strlen (buffer);
        buffer = mime_decode_header (buffer, len, 1);

        scan_header (buffer, fun);
        xfree (buffer);

        str = wrapbox_mail_body (mail, NULL, 1);
        if (str){
                scan_body (str->str, fun);
                str_destroy (str);
        }
}

/****************************************************************************
 *    PROBABILITY FUNCTIONS
 ****************************************************************************/

/**
 * Compute the spam probability of WORD (scaled by PROB_CERTAIN) from its
 * counts in good_table and bad_table.  Occurrences in legitimate mail
 * are weighted twice, as in Graham's article.
 *
 * - Fewer than PROB_MIN_OCCUR weighted occurrences: PROB_NEUTRAL.
 * - Seen only in spam: PROB_MAX, or PROB_MAX_MAX if seen more than
 *   10 times.
 * - Seen only in legitimate mail: PROB_MIN.
 * - Otherwise: PROB_CERTAIN * b / (b + g), clamped to PROB_MAX_MAX,
 *   with b = bad / bad_messages and g = 2 * good / good_messages.
 *
 * The ratio is computed in floating point.  The former integer
 * computation overflowed (undefined behaviour) for frequent words, and
 * divided by zero when a message counter was 0 while word counts
 * remained.  That state is now answered with PROB_NEUTRAL.
 */
static int
prob_word (char *word)
{
        entry_t *good_e     = htable_lookup (good_table, word);
        entry_t *bad_e      = htable_lookup (bad_table, word);
        int      good_occur = 0;
        int      bad_occur  = 0;
        double   bad_freq;
        double   good_freq;
        double   result;

        if (good_e)
                good_occur = 2 * (int) (intptr_t) good_e->content;
        if (bad_e)
                bad_occur = (int) (intptr_t) bad_e->content;

        if (bad_occur + good_occur < PROB_MIN_OCCUR)
                return PROB_NEUTRAL;

        if (good_occur == 0)
                return (bad_occur > 10) ? PROB_MAX_MAX : PROB_MAX;

        if (bad_occur == 0)
                return PROB_MIN;

        if (bad_messages <= 0 || good_messages <= 0)
                return PROB_NEUTRAL;

        bad_freq  = (double) bad_occur / bad_messages;
        good_freq = (double) good_occur / good_messages;

        /* bad_freq > 0 here, so the denominator is never zero. */
        result = PROB_CERTAIN * bad_freq / (bad_freq + good_freq);

        if (result > PROB_MAX_MAX)
                result = PROB_MAX_MAX;

        return (int) result;
}



/**
 * Combine the probabilities of the words in heap slots 1..15 into a
 * message probability (scaled by PROB_CERTAIN):
 *
 *     prod(p) / (prod(p) + prod(1 - p))
 *
 * Empty slots count as PROB_NEUTRAL.  The products are accumulated in
 * double.  Fifteen factors near PROB_MIN reach float's denormal range,
 * where precision is lost.  The debug line prints the value as a
 * decimal fraction with four digits; the former ".%d" printed e.g.
 * 123 as ".123" instead of "0.0123".
 */
static int
prob_heap (void)
{
        double prod     = 1.0;
        double opp_prod = 1.0;
        double p;
        int    i;
        int    result;

        for (i = 1; i < HEAP_SIZE; i++){
                if (word_heap[i])
                        debug_msg (DEBUG_INFO, "    %s: %d", word_heap[i]->key,
                                   (int) word_heap[i]->content);
                p         = heap_value (i) / (double) PROB_CERTAIN;
                prod     *= p;
                opp_prod *= 1.0 - p;
        }
        result = (int) (PROB_CERTAIN * (prod / (prod + opp_prod)));
        debug_msg (DEBUG_INFO, "message prob: %d.%04d",
                   result / PROB_CERTAIN, result % PROB_CERTAIN);
        return result;
}




/**
 * htable_iterator callback: store the computed probability of ENTRY's
 * key in prob_table.
 */
static void
prob_insert_word (entry_t *entry)
{
        htable_insert (prob_table, entry->key, (void *) prob_word (entry->key));
}



/**
 * (Re)build prob_table from every word in good_table and bad_table.
 * The new table's exponent is one more than the larger of the two
 * source exponents.
 */
static void
prob_create_table (void)
{
        int exponent = MAX (good_table->exponent, bad_table->exponent) + 1;

        if (prob_table)
                htable_destroy (prob_table, destroy_int);

        prob_table = htable_create (exponent);
        htable_iterator (good_table, prob_insert_word);
        htable_iterator (bad_table, prob_insert_word);
}


/****************************************************************************
 *    DISCRIMINATION FUNCTIONS
 ****************************************************************************/


/**
 * Token callback used during classification.  Looks WORD up in
 * prob_table, inserting it as PROB_NEUTRAL if unknown.  A word that is
 * not already in the heap is offered to it via the scratch slot.
 */
static void
verify_word (char *word)
{
        entry_t *entry = htable_insert (prob_table, word, (void *) PROB_NEUTRAL);

        if (! heap_contains (entry)){
                word_heap[0] = entry;
                heapify ();
        }
}


/**
 * Reset the word heap and feed every word of MAIL's decoded headers and
 * body through verify_word.  If the header cannot be fetched, the heap
 * is left empty.
 */
static void
verify_message (mail_t *mail)
{
        char  *header = NULL;
        int    header_len;
        str_t *body;

        heap_empty ();

        if (wrapbox_mail_header (mail, &header))
                return;

        header_len = strlen (header);
        header     = mime_decode_header (header, header_len, 1);
        scan_header (header, verify_word);
        xfree (header);

        body = wrapbox_mail_body (mail, NULL, 1);
        if (body == NULL)
                return;

        scan_body (body->str, verify_word);
        str_destroy (body);
}



/**
 * Classify MAIL.  Builds prob_table on first use and returns 1 when the
 * combined probability exceeds PROB_LIMIT, 0 otherwise.
 */
static int
verify_is_spam (mail_t *mail)
{
        int score;

        if (! prob_table)
                prob_create_table ();

        verify_message (mail);
        score = prob_heap ();
        return score > PROB_LIMIT;
}

/****************************************************************************
 *    INTERFACE FUNCTIONS
 ****************************************************************************/

/**
 * Initialize the module once.  Tables are loaded from the "bayes_file"
 * setting when one is given.  Any table still missing afterwards is
 * created empty (exponent 10).
 */
void
bayes_init (void)
{
        char *fname;

        if (initialized)
                return;

        fname = ask_for_default ("bayes_file", NULL);
        if (fname && *fname)
                load_from_file (fname);

        if (! good_table)
                good_table = htable_create (10);
        if (! bad_table)
                bad_table = htable_create (10);

        initialized = 1;
}



/**
 * Save the learned tables to the "bayes_file" setting (if one is
 * configured) and release all tables.  Does nothing if the module was
 * never initialized.
 *
 * An empty "bayes_file" means "no file", as in every other entry point.
 * Previously it was passed to dump_tables_to_file, producing a spurious
 * error at shutdown.
 */
void
bayes_free_resources (void)
{
        char *fname;

        if (!initialized)
                return;

        fname = ask_for_default ("bayes_file", NULL);
        if (fname != NULL && *fname != '\0')
                dump_tables_to_file (fname);

        if (good_table){
                htable_destroy (good_table, destroy_int);
                good_table = NULL;
        }
        if (bad_table){
                htable_destroy (bad_table,  destroy_int);
                bad_table  = NULL;
        }
        if (prob_table){
                htable_destroy (prob_table, destroy_int);
                prob_table = NULL;
        }

        initialized = 0;
}



/**
 * Ask for a file name ("bayes_file") and save the binary tables there.
 * An empty answer cancels the operation.
 */
void
bayes_dump_tables (void)
{
        char *fname;

        PREAMBLE;

        fname = ask_for_simple ("bayes_file");
        if (fname != NULL && *fname != '\0')
                dump_tables_to_file (fname);
}



void
bayes_load_tables (void)
{
        char *fname= ask_for_simple ("bayes_file");

        if (fname == NULL || *fname == '\0')
                return;

        load_from_file (fname);
}



/**
 * Ask for a file name ("bayes_file_txt") and write a human-readable dump
 * of both tables there.  If the file exists, the user is asked before
 * it is overwritten.
 *
 * Fixes:
 *  - The answer of ask_if_sure is kept in an int.  With `char`, the
 *    `case -1` never matched where char is unsigned, and any unmatched
 *    answer fell through with fp == NULL into dump_table_txt.  Anything
 *    other than 1 now cancels.
 *  - The overwrite path opens with O_TRUNC, so a shorter dump does not
 *    leave the tail of the old file behind.
 */
void
bayes_dump_tables_txt (void)
{
        char *fname;
        FILE *fp;
        int   ret;

        PREAMBLE;

        fname = ask_for_simple ("bayes_file_txt");

        if (fname == NULL || *fname == '\0' || *fname == '\n')
                return;

        fp = file_open (fname, "w", O_WRONLY | O_CREAT | O_EXCL, 0644);
        if (fp == NULL){
                if (errno != EEXIST){
                        error_ (errno, "%s", fname);
                        return;
                }

                /* Presumably the next prompt may reuse the buffer fname
                   points into, hence the private copy. */
                fname = xstrdup (fname);

                ret = ask_if_sure ("file already exists, overwrite? ");
                if (ret != 1){
                        xfree (fname);
                        return;
                }

                fp = file_open (fname, "w", O_WRONLY | O_CREAT | O_TRUNC, 0644);
                if (fp == NULL){
                        error_ (errno, "%s", fname);
                        xfree (fname);
                        return;
                }
                xfree (fname);
        }

        dump_table_txt (fp, good_table);
        dump_table_txt (fp, bad_table);

        fclose (fp);
}



/**
 * Train on MAIL as spam: count its words in the spam table and increment
 * the spam message counter.
 */
void
bayes_scan_spam (mail_t *mail)
{
        PREAMBLE;

        collect_message (mail, 1, +1);
        bad_messages += 1;
}



/**
 * Train on MAIL as legitimate: count its words in the good table and
 * increment the legitimate message counter.
 */
void
bayes_scan_legitimate (mail_t *mail)
{
        PREAMBLE;

        collect_message (mail, 0, +1);
        good_messages += 1;
}



/**
 * Undo bayes_scan_spam for MAIL: decrement its words in the spam table
 * and decrement the spam message counter.  Does nothing when no spam has
 * been scanned.
 *
 * bayes_scan_spam counts the words in the spam table, so they must be
 * removed from the spam table too.  Previously is_spam was 0 here, so
 * the legitimate counts were decremented instead.
 */
void
bayes_unscan_spam (mail_t *mail)
{
        PREAMBLE;

        if (bad_messages == 0)
                return;

        collect_message (mail, 1, -1);
        bad_messages--;
}



/**
 * Undo bayes_scan_legitimate for MAIL: decrement its words in the good
 * table and decrement the legitimate message counter.  Does nothing when
 * no legitimate mail has been scanned.
 */
void
bayes_unscan_legitimate (mail_t *mail)
{
        PREAMBLE;

        if (good_messages != 0){
                collect_message (mail, 0, -1);
                good_messages -= 1;
        }
}



/**
 * Public classifier entry point: returns nonzero if MAIL looks like spam.
 */
int
bayes_is_spam (mail_t *mail)
{
        int verdict;

        PREAMBLE;

        verdict = verify_is_spam (mail);
        return verdict;
}


/****************************************************************************
 *    INTERFACE CLASS BODIES
 ****************************************************************************/
/****************************************************************************
 *
 *    END MODULE bayes.c
 *
 ****************************************************************************/
