// SpamStochCheck - by Frank Thilo - thilo@unix-ag.org
// inspired by an idea from Paul Graham - http://www.paulgraham.com/spam.html
//
// checks a mail against a word database and calculates a SPAM probability
// using a Bayesian approach

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <list>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <cctype>
#include <clocale>
#include <cmath>
#include <cstdlib>
#include <cstring>

// ---- tunable parameters (adjust as needed) --------------------------------

// a mail counts as SPAM when its combined probability is above this value
const double SpamThreshold = 0.995;
// probability used for a mail word that does not occur in the database
const double NewWordProb = 0.2;

// only words of MinTokenLength..MaxTokenLength characters become tokens
const unsigned int MinTokenLength = 2;
const unsigned int MaxTokenLength = 20;

// how many of the most significant tokens enter the final calculation
const unsigned int NumInterestingTokens = 15;

// ---- container types ------------------------------------------------------

typedef std::vector<std::string> wordlist_t;   // words found in the mail
typedef std::pair<std::string,double> token;   // (word, spam probability)
typedef std::list<token> besttokens_t;         // the selected tokens

// ---- program state --------------------------------------------------------

wordlist_t WordList;      // all tokens of the mail, sorted
besttokens_t BestTokens;  // selected tokens, least significant first
std::string ProbFilename; // path of the probability database (-p)
bool verbose = false;     // print per-token details (-v), i.e. talkative mode

///////////////////////////

// defines linear orders for tokens based on significance (farthest from 0.5)
bool tokenless(const token &t1,const token &t2)
{
  return fabs(t1.second-0.5) < fabs(t2.second-0.5);
}

///////////////////////////

// read one database entry from probfile
// An entry is a word terminated by a single blank, followed by exactly three
// decimal digits and one more character that is skipped (presumably the
// newline).  The digits are returned in prob as an integer 000..999.
// No error checking is done: on a malformed or exhausted stream prob is
// meaningless, so callers must look at the stream state / word themselves.
void GetNextWord(std::istream &probfile,char *word,int &prob)
{
  probfile.getline(word,MaxTokenLength+1,' ');

  prob=0;
  for (int digit=0;digit<3;++digit)
    prob=prob*10+(probfile.get()-'0');

  probfile.ignore();
}

///////////////////////////

// determine the NumInterestingTokens most significant tokens
void GetBestTokensFromMail()
{
  std::ifstream probfile(ProbFilename.c_str());
  if (!probfile)
    throw std::runtime_error("Could not read file: "+ProbFilename);

  unsigned int count=0;

  // iterate over sorted mail word list
  char word[MaxTokenLength+1]="";
  for (wordlist_t::const_iterator i=WordList.begin();i!=WordList.end();++i)
  {
    const std::string &mailword=*i;
    int p,comp;

    // read words from database until match or too far
    while ((comp=strcmp(word,mailword.c_str()))<0 && probfile)
      GetNextWord(probfile,word,p);

    double prob= comp==0 ? p/1000.0 : NewWordProb; // match?

    // put into list if among 15 most significant tokens
    token token(mailword,prob);
    if (tokenless(BestTokens.front(),token) || count<NumInterestingTokens)
    {
      besttokens_t::iterator it=
        lower_bound(BestTokens.begin(),BestTokens.end(),token,tokenless);
      BestTokens.insert(it,token);
      if (count++>=NumInterestingTokens)
        BestTokens.pop_front();
    }
    // we had a match, so read next word (dupe handling)
    if (comp==0)
      GetNextWord(probfile,word,p);
  }
}

///////////////////////////

// defines which characters make up the tokens: letters and digits (as
// classified by the current LC_CTYPE locale) plus '-', '$' and '\''
bool istokenchar(char c)
{
  // isalnum() only accepts values representable as unsigned char (or EOF);
  // a plain char holding an 8-bit character may be negative, which would be
  // undefined behaviour, so convert before classifying
  return std::isalnum(static_cast<unsigned char>(c)) ||
         c=='-' || c=='$' || c=='\'';
}

///////////////////////////

// read one token from input into word (a buffer of MaxTokenLength+1 chars)
// Returns the number of token characters stored.  The character that ends
// the token is consumed.  A result of MaxTokenLength+1 means the token is
// too long: word is then completely filled and NOT null-terminated.
int GetToken(std::istream &input,char* word)
{
  unsigned int len=0;
  while (len<MaxTokenLength+1)
  {
    const char ch=input.get();
    if (!input || !istokenchar(ch))
    {
      word[len]=0;
      return len;
    }
    word[len++]=ch;
  }
  return len;
}

///////////////////////////

// discard token characters up to and including the first non-token
// character (or until the stream fails)
void SkipToken(std::istream &input)
{
  for (;;)
  {
    const char ch=input.get();
    if (!istokenchar(ch) || !input)
      break;
  }
}

///////////////////////////

// discard non-token characters; the first token character found is pushed
// back so that the next GetToken call starts with it
void SkipNonToken(std::istream &input)
{
  char ch;
  for (;;)
  {
    ch=input.get();
    if (istokenchar(ch) || !input)
      break;
  }
  input.putback(ch);
}

///////////////////////////

// creates a sorted list of tokens in the mail
void MakeWordlist(std::istream &input)
{
  while (input)
  {
    char word[MaxTokenLength+1];

    unsigned int length=GetToken(input,word);

    if (length>MaxTokenLength)           // token too long?
      SkipToken(input);
    else if (length>=MinTokenLength)     // we got a valid token, store it
      WordList.push_back(word);

    SkipNonToken(input);                 // skip non-token chars
  }
  sort(WordList.begin(),WordList.end()); // finally sort our word list
}

///////////////////////////

// calculate SPAM probability from word probabilities (Bayesian)
double CalcSpamProb()
{
  double P=1.0;
  for (besttokens_t::iterator i=BestTokens.begin();i!=BestTokens.end();++i)
    P*=i->second;

  double Pinv=1.0;
  for (besttokens_t::iterator i=BestTokens.begin();i!=BestTokens.end();++i)
    Pinv*=1.0-i->second;

  double prob=P/(P+Pinv);
  return prob;
}

///////////////////////////

void PrintProbabilities(double spamprob)
{
  for (besttokens_t::const_iterator i=BestTokens.begin();
       i!=BestTokens.end();++i)
  {
    std::cout << i->first.c_str() << ' ' << i->second*100 << "%\n";
  }
  std::cout << "Spam probability: " << std::setiosflags(std::ios::fixed) <<
    std::setprecision(1) << spamprob*100 << "%\n" << std::endl;
}

///////////////////////////

// print a short usage summary for program name argv0 to stderr
void Usage(char *argv0)
{
  std::ostream &err=std::cerr;
  err << "Usage: " << argv0 << " [-v] [-l <locale>] -p <prob file>\n"
      << "reads mail from stdin\n"
      << "returns 0=good mail 1=error 2=spam" << std::endl;
}

///////////////////////////

// parse the command line into the global settings
//   -p <file>    probability database (stored in ProbFilename, mandatory)
//   -l <locale>  LC_CTYPE locale used to classify token characters
//   -v           verbose output
//   -h           print usage and exit(0)
// throws std::runtime_error for unknown options, unexpected arguments,
// a -p/-l option lacking its argument, an unusable locale or a missing -p
void ParseArgs(int argc,char *argv[])
{
  enum ParseState {none,prob,locale} state=none;
  int pending=0;   // index of the -p/-l option still waiting for its argument

  for (int i=1;i<argc;i++)
  {
    if (argv[i][0]=='-')
    {
      if (argv[i][1]==0 || argv[i][2]!=0)
        throw std::runtime_error(std::string(argv[0])+": invalid option -- "+argv[i]);
      // previously this silently discarded the unfinished -p/-l option
      if (state!=none)
        throw std::runtime_error(std::string(argv[0])+": option requires an argument -- "+argv[pending]);
      switch (argv[i][1])
      {
        case 'p':
          state=prob;
          pending=i;
          break;

        case 'l':
          state=locale;
          pending=i;
          break;

        case 'v':
          verbose=true;
          break;

        case 'h':
          Usage(argv[0]);
          exit(0);

        default:
          throw std::runtime_error(std::string(argv[0])+": invalid option -- "+argv[i]);
      }
    }
    else
    {
      switch (state)
      {
        case prob:
          ProbFilename=argv[i];
          state=none;
          break;

        case locale:
          if (!setlocale(LC_CTYPE,argv[i]))
            throw std::runtime_error(std::string(argv[0])+": could not set locale -- "+argv[i]);
          state=none;
          break;

        default:
          throw std::runtime_error("unexpected -- "+std::string(argv[i]));
      }
    }
  }
  // the last -p/-l on the command line was not followed by its argument
  if (state!=none)
    throw std::runtime_error(std::string(argv[0])+": option requires an argument -- "+argv[pending]);
  if (ProbFilename=="")
    throw std::runtime_error("you have to specify the probabilities file with -p");
}


///////////////////////////

// classify the mail read from stdin
// return 0 for good mail, 2 for SPAM and 1 in case of error;
// all diagnostics are written to stderr
int main(int argc,char *argv[])
{
  std::ios::sync_with_stdio(false); // speeds up I/O for some implementations
  setlocale(LC_CTYPE,"");           // token classification follows the environment

  try
  {
    ParseArgs(argc,argv);
  }
  catch (const std::exception &e)
  {
    std::cerr << e.what() << std::endl;
    std::cerr << "Try '" << argv[0] << " -h' for more information" << std::endl;
    return 1;
  }

  try
  {
    MakeWordlist(std::cin);
    GetBestTokensFromMail();
    const double spamprob=CalcSpamProb();

    if (verbose)
      PrintProbabilities(spamprob);
    return spamprob > SpamThreshold ? 2 : 0;
  }
  catch (const std::exception &e)
  {
    // stderr, like the argument errors above: stdout carries the verbose report
    std::cerr << "Fatal Error: " << e.what() << std::endl;
    return 1;
  }
}

