/*
   $Id: bact_learn.cpp,v 1.1.1.1 2004/06/23 05:00:42 taku-ku Exp $;

   Copyright (C) 2003 Taku Kudo, All rights reserved.
   This is free software with ABSOLUTELY NO WARRANTY.

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.
  
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
  
   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
   02111-1307, USA
*/
 
#include <unistd.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <numeric>
#include <set>
#include <sstream>
#include <string>
#include <strstream>
#include <vector>

#include "common.h"

using namespace std;

class BactLearner {

private:
  
  struct rule_t {
    double gain;
    unsigned int size;
    std::string  subtree;
    std::vector <unsigned int> loc;
    friend bool operator < (const rule_t &r1, const rule_t &r2) 
    {
       return r1.subtree < r2.subtree; 
    }
  };

  struct space_t {
    int prefix;
    int last;
    space_t *prev;     
    stx::string_symbol rme;
    stx::string_symbol rmed;
    std::vector <int>       loc;
    std::vector <space_t *> next;

    void add (unsigned int i, int j) 
    {
      if (last != (int)i) loc.push_back (-(int)(i+1));
      loc.push_back (j);
      last = (int)i;
    }

    void shrink () 
    {
      std::vector<int> tmp;
      for (unsigned int i = 0; i < loc.size(); ++i) 
	if (loc[i] < 0) tmp.push_back (loc[i]);
      loc = tmp;
      std::vector<int>(loc).swap (loc); // shrink
      last = -1;
    }

    unsigned int support ()
    {
      unsigned int result = 0;
      for (unsigned int i = 0; i < loc.size(); ++i) 
	if (loc[i] < 0) ++result;
      return result;
    }

    space_t(): prefix(0), last(-1), prev(0) {};
  };

  stx::string_symbol                   bc_mark;
  std::vector < std::vector <node_t> > transaction;
  std::vector < int >                  y;
  std::vector < double >               w;

  std::set < rule_t >           rule_cache;
  std::set <stx::string_symbol> single_node_minsup_cache;
  std::set <stx::string_symbol> single_node_cache;
  std::map <stx::string_symbol, std::set <stx::string_symbol> > double_node_cache;

  double       wbias;
  rule_t       rule;
  double       tau; 
  unsigned int maxpat;
  unsigned int minsup;
  unsigned int total;
  unsigned int pruned;
  unsigned int rewritten;

  bool read (const char *filename)
  {
    char line [8192 * 128];
    char *column[5];

    std::ifstream ifs (filename);
    if (! ifs) return false;

    while (ifs.getline (line, 8192 * 128)) {
      if (line[0] == '\0' || line[0] == ';') continue;
      transaction.resize (transaction.size()+1);
       if (2 != tokenize (line, "\t ", column, 2)) {
	 std::cerr << "FATAL: Format Error: " << line << std::endl;
	 return false;
      }

      int _y = atoi (column[0]);
      str2node (column[1], transaction[transaction.size()-1]);
      y.push_back (_y > 0 ? +1 : -1);
    }

    return true;
  }

  double calc_gain (const std::vector <unsigned int> &loc) 
  {
    double gain = - wbias;
    for (unsigned int i = 0; i < loc.size(); ++i) 
      gain += 2 * y[loc[i]] * w[loc[i]];
    return gain;
  }

  bool can_prune (space_t *space, unsigned int size)
  {
    stx::string_symbol& rme = space->rme;

    ++total;

    if (size >= 2 && single_node_cache.find (rme) == single_node_cache.end()) {
      ++pruned; 
      return true; 
    }

    if (size >= 3 && double_node_cache[space->rmed].find (rme) == double_node_cache[space->rmed].end()) {
       ++pruned; 
      return true;
    }

    double upos = -wbias;
    double uneg = wbias;
    double gain = -wbias;
    unsigned int support = 0; 

    std::vector <int>& loc = space->loc;

    for (unsigned int i = 0; i < loc.size(); ++i) {
      if (loc[i] >= 0) continue;
      ++support; 
      unsigned int j = (unsigned int)(-loc[i]) - 1;
      gain += 2 * y[j] * w[j];
      if (y[j] > 0) upos += 2 * w[j];
      else          uneg += 2 * w[j];
    }
     
    if (support < minsup || std::max (upos, uneg) <= tau) {
      ++pruned;
      return true;
    }

    double g = std::abs (gain);

    if (size == 1) single_node_cache.insert (space->rme);
    if (size == 2) double_node_cache[space->prev->rme].insert (space->rme);

    if (g > tau || g == tau && size < rule.size) { 

      ++rewritten;

      tau = g;
      rule.gain = gain;
      rule.size = size;

      std::ostrstream ostrs;
      std::vector <stx::string_symbol> pattern;
      for (space_t *t = space; t; t = t->prev) {
	 pattern.push_back (t->rme);
	 for (int d = 0; d < t->prefix; ++d) pattern.push_back (bc_mark);
      }
      std::reverse (pattern.begin (), pattern.end());
      std::copy (pattern.begin (), pattern.end()-1, std::ostream_iterator<stx::string_symbol>(ostrs, " "));
      ostrs << pattern[pattern.size()-1] << std::ends;
      rule.subtree = ostrs.str ();

      rule.loc.clear ();
      for (unsigned int i = 0; i < space->loc.size(); ++i)
	if (space->loc[i] < 0) rule.loc.push_back ((unsigned int)(-space->loc[i]) - 1);

    }
    
    return false;
  }

  void span (space_t *space, std::vector<space_t *>& new_space, unsigned int size)
  {
    std::vector <space_t *>& next = space->next;
     
    if (next.size() == 1 && next[0] == 0) {
       return;

    } else if (! next.empty()) {

      for (std::vector<space_t *>::iterator it = next.begin(); it != next.end(); ++it) 
	if (! can_prune (*it, size)) new_space.push_back ((*it));
      
    } else  { 

      // obtain the list of nodes on Right-Most-Path
      std::vector <stx::string_symbol> rmnode;
      int n = 0;
      for (space_t *t = space; t; t = t->prev) {
	 if (n == 0) rmnode.push_back (t->rme);
	 else --n;
	 n += t->prefix;
      }

      int depth = rmnode.size();
      unsigned int id = 0;    
      std::map <int, map<stx::string_symbol, space_t > > candidate;
      std::vector <int>& loc = space->loc;

      for (unsigned int i = 0; i < loc.size(); ++i) {

	if (loc[i] < 0) {
	  id = (unsigned int)(-loc[i]) - 1;
	  continue;
	}

	int pos = loc[i];

	for (int prefix = 0 ; prefix < depth && pos != -1; ++prefix) {
	  int start = (prefix == 0) ? transaction[id][pos].child : transaction[id][pos].sibling;
	  for (int l = start; l != -1; l = transaction[id][l].sibling) {
	     if (minsup == 1 || single_node_minsup_cache.find (transaction[id][l].val) != single_node_minsup_cache.end())
	       candidate[prefix][transaction[id][l].val].add (id, l);
	  }
	  if (prefix != 0) pos = transaction[id][pos].parent;
	}
      }

      for (std::map <int, map<stx::string_symbol, space_t > >::iterator it1 = candidate.begin();
	   it1 != candidate.end(); ++it1) {

	for (std::map <stx::string_symbol, space_t >::iterator it2 = it1->second.begin();
	     it2 != it1->second.end(); ++it2) {

	  space_t *c = new space_t;
	  c->loc = it2->second.loc;
	  std::vector<int>(c->loc).swap (c->loc);
	  c->prefix = it1->first;
	  c->rme    = it2->first;
	  c->rmed   = rmnode.empty() ? it2->first : rmnode[c->prefix];
	  c->prev   = space;
	  c->next.clear ();
	  next.push_back (c);

	  if (! can_prune (c, size)) new_space.push_back (c);
	}
      }

      space->shrink ();

      if (next.empty()) next.push_back (0);
       
      std::vector<space_t *>(next).swap (next);
    }

    return;
  }

public:

  bool run (const char *in, 
	    const char *out,
	    unsigned int _maxpat,
	    unsigned int maxitr,
	    unsigned int _minsup,
	    unsigned int _subitr,
	    double prob,
	    int type )
  { 
    maxpat = _maxpat;
    minsup = _minsup;

    bc_mark = ")";

    if (! read (in)) {
      std::cerr << "FATAL: Cannot open input file: " << in << std::endl;
      return false;
    }

    std::ofstream os (out);
    if (! os) {
      std::cerr << "FATAL: Cannot open output file: " << out << std::endl;
      return false;
    }

    std::cout.setf(std::ios::fixed,std::ios::floatfield);
    std::cout.precision(8);

    os.setf(std::ios::fixed,std::ios::floatfield);
    os.precision(12);

    unsigned int l   = transaction.size();
    double alpha_sum = 0.0;
    double margin    = 0.0;
    tau              = 0.0;
    wbias            = 0.0;

    w.resize (l);
    std::fill (w.begin(), w.end(), 1.0/l);

    std::vector <double>  F (l);
    std::fill (F.begin(), F.end(), 0.0);
    std::vector <int> result (l);

    rule.subtree = "";
    rule.loc.clear ();
    rule.size = 0;

    for (unsigned int i = 0; i < l; ++i) wbias += y[i] * w[i];

    std::map <stx::string_symbol, space_t> seed;
    for (unsigned int i = 0; i < l; ++i) {
      for (unsigned int j = 0; j < transaction[i].size(); ++j) {
	space_t & tmp = seed[transaction[i][j].val];
	tmp.add (i,j);
	tmp.prefix = 0;
	tmp.next.clear ();
	tmp.rme = transaction[i][j].val;
	tmp.prev = 0;
      }
    }

    for (std::map <stx::string_symbol, space_t>::iterator it = seed.begin (); it != seed.end(); ++it) {
       if (it->second.support() < minsup) seed.erase (it++);
       else  single_node_minsup_cache.insert (it->second.rme);
    }

    std::vector <space_t *> old_space;
    std::vector <space_t *> new_space;

    unsigned int subitr = (unsigned int)(1.0 * _subitr * prob);

    for (unsigned int itr = 0; itr < maxitr; ++itr) {

      pruned = total = rewritten = 0;

      if ((itr % _subitr) < subitr) {

	single_node_cache.clear();
	double_node_cache.clear();
	old_space.clear ();
	new_space.clear ();

	for (std::map <stx::string_symbol, space_t>::iterator it = seed.begin (); it != seed.end(); ++it) 
	  if (! can_prune (&(it->second), 1)) old_space.push_back (&(it->second));

	for (unsigned int size = 2; size <= maxpat; ++size) {
	  for (unsigned int i = 0; i < old_space.size(); ++i) span (old_space[i], new_space, size);
	  if (new_space.empty()) break;
	  old_space = new_space;
	  new_space.clear ();
	}
      } 

      rule_cache.insert (rule);
      
      if (margin <= -0.7) margin = -0.7;
      if (type) margin = 0.0;
      double alpha = log ((1 + std::abs(rule.gain))/(1 - std::abs(rule.gain)))/2.0 + log ((1 - margin) / (1 + margin)) / 2.0; 

      int _y = rule.gain > 0 ? +1 : -1; // class h<t,y> 
      std::fill (result.begin (), result.end(), -_y);
      for (unsigned int i = 0; i < rule.loc.size(); ++i) result[rule.loc[i]] = _y;

      unsigned int error_num = 0;
      margin = 1e+37;
      alpha_sum += alpha;
      for (unsigned int i = 0; i < l;  ++i) {
	F[i] += alpha * result[i];
	if (y[i] * F[i] < 0) ++error_num;
	w[i] = exp (- y[i] * F[i]);
	margin = std::min (margin, y[i] * F[i] / alpha_sum);
      }

      // normalize weights
      std::transform (w.begin(), w.end(), w.begin(), std::bind2nd(std::divides<double>(), 
								  std::accumulate (w.begin(), w.end(), 0.0)));

      // next rule is estimated
      wbias = 0.0;
      for (unsigned int i = 0; i < l; ++i) wbias += y[i] * w[i];

      // output
      os        << alpha * _y << ' ' << rule.subtree << std::endl;

      std::cout <<  itr << " " << rule_cache.size () << " " << rewritten << "/" << pruned << "/" << total << " "
		<< 1.0 * error_num / l << " " << margin << " " << alpha * _y << " " << rule.subtree << std::endl;
       
      tau = -1e+37;
      double gain = 0.0;
      std::set <rule_t>::iterator mit = rule_cache.end();  
      for (std::set<rule_t>::iterator it = rule_cache.begin(); it != rule_cache.end(); ++it) { 
	double g = calc_gain (it->loc);
	if (tau < std::abs (g)) {
	  tau  = std::abs (g);
	  gain = g;
	  mit  = it;
	}
      }

      rule = *mit;
      rule.gain = gain;

    }

    return true;
  }
};

// Usage text appended to the program name.  The option letters must match
// the getopt() string in main(): -s is the sub-round cycle length and -t the
// type flag.
#define OPT " [-m minsup] [-L maxpat] [-T #round] [-s #subround] [-p #prob] [-t (0|1)] train model"

// Command-line driver: parses the options, then trains on the `train` file
// and writes the model to the `model` file (the last two arguments).
// Returns 0 on success and -1 on a usage error or when training fails.
int main (int argc, char **argv)
{
  extern char *optarg;
  extern int   optind;
  unsigned int maxpat = 0xffffffff;
  unsigned int maxitr = 10000;
  unsigned int minsup = 1;
  unsigned int subitr = 100;
  double prob = 1.0;
  int type = 0;

  int opt;
  while ((opt = getopt(argc, argv, "T:L:m:t:p:s:")) != -1) {
    switch(opt) {
    case 't':
      type = atoi (optarg);
      break;
    case 'p':
      prob = atof (optarg);
      if (prob <= 0 || prob > 1.0) {
        std::cerr << "Prob is invalid " << prob << std::endl;
        return -1;
      }
      break;
    case 's':
      subitr = atoi (optarg);
      break;
    case 'L':
      maxpat = atoi (optarg);
      break;
    case 'T':
      maxitr = atoi (optarg);
      break;
    case 'm':
      minsup = atoi (optarg);
      break;
    default:
      std::cout << "Usage: " << argv[0] << OPT << std::endl;
      return -1;
    }
  }

  // Both file names are required.  Checking only argc would let a single
  // argument through and make argv[0] the training file.
  if (argc - optind < 2) {
    std::cerr << "Usage: " << argv[0] << OPT << std::endl;
    return -1;
  }

  BactLearner bl;
  if (! bl.run (argv[argc-2], argv[argc-1], maxpat, maxitr, minsup, subitr, prob, type))
    return -1;

  return 0;
}
