/*	$NetBSD: matcher.c,v 1.1.1.1 2008/12/22 00:18:36 haad Exp $	*/

/*
 * Copyright (C) 2001-2004 Sistina Software, Inc. All rights reserved.  
 * Copyright (C) 2004-2007 Red Hat, Inc. All rights reserved.
 *
 * This file is part of the device-mapper userspace tools.
 *
 * This copyrighted material is made available to anyone wishing to use,
 * modify, copy, or redistribute it subject to the terms and conditions
 * of the GNU Lesser General Public License v.2.1.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include "dmlib.h"
#include "parse_rx.h"
#include "ttree.h"
#include "assert.h"

/*
 * One state of the compiled DFA.
 *
 * final:  0 if the state is not accepting; otherwise a positive number
 *         copied from the 'final' field of a parse-tree node whose
 *         charset contains TARGET_TRANS (numbering starts at 1).
 *         dm_regex_match() keeps the largest value seen and subtracts 1.
 * lookup: successor state for each input byte value; NULL means there
 *         is no transition and matching stops.
 */
struct dfa_state {
	int final;
	struct dfa_state *lookup[256];
};

/*
 * Work-list entry used by _calc_states() while building the DFA.
 *
 * s:    DFA state still waiting to have its transitions computed.
 * bits: private copy of the set of positions (indices into
 *       dm_regex.nodes) that 's' stands for.
 * next: next entry in the FIFO, NULL at the tail.
 */
struct state_queue {
	struct dfa_state *s;
	dm_bitset_t bits;
	struct state_queue *next;
};

struct dm_regex {		/* Instance variables for the lexer */
	struct dfa_state *start;	/* initial DFA state; matching begins here */
	unsigned num_nodes;		/* number of nodes in the parse tree */
	int nodes_entered;		/* fill cursor used by _fill_table() */
	struct rx_node **nodes;		/* parse-tree nodes in post-order; lives in
					   'scratch', so only valid while building */
	struct dm_pool *scratch, *mem;	/* scratch: temporary pool, destroyed and
					   set to NULL once the DFA is built;
					   mem: caller's pool holding this struct
					   and the DFA states */
};

/*
 * Pseudo-character appended after every pattern by dm_regex_create().
 * A parse-tree node whose charset contains it is given a non-zero
 * 'final' number, and a DFA state that can take this transition
 * becomes accepting.
 */
#define TARGET_TRANS '\0'

/*
 * Return the number of nodes in the parse tree rooted at 'rx',
 * counting 'rx' itself.
 */
static int _count_nodes(struct rx_node *rx)
{
	int in_left = rx->left ? _count_nodes(rx->left) : 0;
	int in_right = rx->right ? _count_nodes(rx->right) : 0;

	return 1 + in_left + in_right;
}

/*
 * Append every node of the tree rooted at 'rx' to m->nodes in
 * post-order (both children before their parent), advancing
 * m->nodes_entered.  Later passes rely on this order: walking m->nodes
 * by index always visits a node's children first.
 */
static void _fill_table(struct dm_regex *m, struct rx_node *rx)
{
	struct rx_node *child[2];
	int k;

	/* an OR node must always carry both alternatives */
	assert(!(rx->type == OR && !(rx->left && rx->right)));

	child[0] = rx->left;
	child[1] = rx->right;

	for (k = 0; k < 2; k++)
		if (child[k])
			_fill_table(m, child[k]);

	m->nodes[m->nodes_entered] = rx;
	m->nodes_entered++;
}

/*
 * Give every parse-tree node its firstpos, lastpos and followpos
 * bitsets, each sized for m->num_nodes bits and allocated from
 * m->scratch.
 *
 * NOTE(review): allocation failures are not reported by this function;
 * a NULL bitset would later be dereferenced by _calc_functions().
 */
static void _create_bitsets(struct dm_regex *m)
{
	unsigned n = 0;

	while (n < m->num_nodes) {
		struct rx_node *node = m->nodes[n++];

		node->firstpos = dm_bitset_create(m->scratch, m->num_nodes);
		node->lastpos = dm_bitset_create(m->scratch, m->num_nodes);
		node->followpos = dm_bitset_create(m->scratch, m->num_nodes);
	}
}

static void _calc_functions(struct dm_regex *m)
{
	int i, j, final = 1;
	struct rx_node *rx, *c1, *c2;

	for (i = 0; i < m->num_nodes; i++) {
		rx = m->nodes[i];
		c1 = rx->left;
		c2 = rx->right;

		if (dm_bit(rx->charset, TARGET_TRANS))
			rx->final = final++;

		switch (rx->type) {
		case CAT:
			if (c1->nullable)
				dm_bit_union(rx->firstpos,
					  c1->firstpos, c2->firstpos);
			else
				dm_bit_copy(rx->firstpos, c1->firstpos);

			if (c2->nullable)
				dm_bit_union(rx->lastpos,
					  c1->lastpos, c2->lastpos);
			else
				dm_bit_copy(rx->lastpos, c2->lastpos);

			rx->nullable = c1->nullable && c2->nullable;
			break;

		case PLUS:
			dm_bit_copy(rx->firstpos, c1->firstpos);
			dm_bit_copy(rx->lastpos, c1->lastpos);
			rx->nullable = c1->nullable;
			break;

		case OR:
			dm_bit_union(rx->firstpos, c1->firstpos, c2->firstpos);
			dm_bit_union(rx->lastpos, c1->lastpos, c2->lastpos);
			rx->nullable = c1->nullable || c2->nullable;
			break;

		case QUEST:
		case STAR:
			dm_bit_copy(rx->firstpos, c1->firstpos);
			dm_bit_copy(rx->lastpos, c1->lastpos);
			rx->nullable = 1;
			break;

		case CHARSET:
			dm_bit_set(rx->firstpos, i);
			dm_bit_set(rx->lastpos, i);
			rx->nullable = 0;
			break;

		default:
			log_error("Internal error: Unknown calc node type");
		}

		/*
		 * followpos has it's own switch
		 * because PLUS and STAR do the
		 * same thing.
		 */
		switch (rx->type) {
		case CAT:
			for (j = 0; j < m->num_nodes; j++) {
				if (dm_bit(c1->lastpos, j)) {
					struct rx_node *n = m->nodes[j];
					dm_bit_union(n->followpos,
						  n->followpos, c2->firstpos);
				}
			}
			break;

		case PLUS:
		case STAR:
			for (j = 0; j < m->num_nodes; j++) {
				if (dm_bit(rx->lastpos, j)) {
					struct rx_node *n = m->nodes[j];
					dm_bit_union(n->followpos,
						  n->followpos, rx->firstpos);
				}
			}
			break;
		}
	}
}

/*
 * Allocate a zero-filled DFA state from 'mem': not accepting
 * (final == 0) and with no transitions (every lookup[] entry NULL).
 * Returns NULL if the allocation fails.
 */
static struct dfa_state *_create_dfa_state(struct dm_pool *mem)
{
	struct dfa_state *state = dm_pool_zalloc(mem, sizeof(*state));

	return state;
}

/*
 * Allocate a work-list entry from 'mem' that holds DFA state 'dfa'
 * and a private copy of the position set 'bits'.
 *
 * Returns the entry, with next == NULL, or NULL if either allocation
 * fails.  In that case nothing allocated here is left behind in 'mem'.
 */
static struct state_queue *_create_state_queue(struct dm_pool *mem,
					       struct dfa_state *dfa,
					       dm_bitset_t bits)
{
	struct state_queue *r = dm_pool_alloc(mem, sizeof(*r));

	if (!r)
		return_NULL;

	r->s = dfa;

	/* first element is the size; without the copy we must not use r */
	if (!(r->bits = dm_bitset_create(mem, bits[0]))) {
		dm_pool_free(mem, r);
		return_NULL;
	}

	dm_bit_copy(r->bits, bits);
	r->next = NULL;
	return r;
}

/*
 * Build the DFA for the parse tree rooted at 'rx' by subset
 * construction over the position sets computed by _calc_functions().
 *
 * Each DFA state stands for a set of positions.  Starting from
 * firstpos(rx), each queued state is expanded for all 256 input
 * values.  For each input value, the followpos sets of the state's
 * positions whose charset contains it are unioned.  The resulting set
 * is looked up in a ttree keyed on the bitset words, so each distinct
 * set yields exactly one DFA state.
 *
 * DFA states are allocated from m->mem; the ttree, the queue and the
 * working bitset come from m->scratch.
 *
 * Returns 1 on success.  Returns 0 if any allocation fails; the DFA is
 * then incomplete and must not be used.
 */
static int _calc_states(struct dm_regex *m, struct rx_node *rx)
{
	unsigned iwidth = (m->num_nodes / DM_BITS_PER_INT) + 1;
	struct ttree *tt = ttree_create(m->scratch, iwidth);
	struct state_queue *h, *t, *tmp;
	struct dfa_state *dfa, *ldfa;
	int i, a, set_bits = 0;
	int count = 1;		/* the start state is a DFA state too */
	dm_bitset_t bs, dfa_bits;

	if (!tt)
		return_0;

	if (!(bs = dm_bitset_create(m->scratch, m->num_nodes)))
		return_0;

	/* create first state */
	if (!(dfa = _create_dfa_state(m->mem)))
		return_0;

	m->start = dfa;

	/* tree keys skip word 0 of the bitset, which holds its size */
	if (!ttree_insert(tt, rx->firstpos + 1, dfa))
		return_0;

	/* prime the queue */
	if (!(h = t = _create_state_queue(m->scratch, dfa, rx->firstpos)))
		return_0;

	while (h) {
		/* pop state off front of the queue */
		dfa = h->s;
		dfa_bits = h->bits;
		h = h->next;

		/* iterate through all the inputs for this state */
		dm_bit_clear_all(bs);
		for (a = 0; a < 256; a++) {
			/* union followpos of every position accepting 'a' */
			for (i = dm_bit_get_first(dfa_bits);
			     i >= 0; i = dm_bit_get_next(dfa_bits, i)) {
				if (dm_bit(m->nodes[i]->charset, a)) {
					/* an end-of-pattern position makes this state accepting */
					if (a == TARGET_TRANS)
						dfa->final = m->nodes[i]->final;

					dm_bit_union(bs, bs,
						  m->nodes[i]->followpos);
					set_bits = 1;
				}
			}

			if (set_bits) {
				ldfa = ttree_lookup(tt, bs + 1);
				if (!ldfa) {
					/* unseen position set: new state, push it */
					if (!(ldfa = _create_dfa_state(m->mem)))
						return_0;

					if (!ttree_insert(tt, bs + 1, ldfa))
						return_0;

					if (!(tmp = _create_state_queue(m->scratch,
									ldfa, bs)))
						return_0;

					if (!h)
						h = t = tmp;
					else {
						t->next = tmp;
						t = tmp;
					}

					count++;
				}

				dfa->lookup[a] = ldfa;
				set_bits = 0;
				dm_bit_clear_all(bs);
			}
		}
	}

	log_debug("Matcher built with %d dfa states", count);
	return 1;
}

struct dm_regex *dm_regex_create(struct dm_pool *mem, const char **patterns,
				 unsigned num_patterns)
{
	char *all, *ptr;
	int i;
	size_t len = 0;
	struct rx_node *rx;
	struct dm_pool *scratch = dm_pool_create("regex matcher", 10 * 1024);
	struct dm_regex *m;

	if (!scratch)
		return_NULL;

	if (!(m = dm_pool_alloc(mem, sizeof(*m)))) {
		dm_pool_destroy(scratch);
		return_NULL;
	}

	memset(m, 0, sizeof(*m));

	/* join the regexps together, delimiting with zero */
	for (i = 0; i < num_patterns; i++)
		len += strlen(patterns[i]) + 8;

	ptr = all = dm_pool_alloc(scratch, len + 1);

	if (!all)
		goto_bad;

	for (i = 0; i < num_patterns; i++) {
		ptr += sprintf(ptr, "(.*(%s)%c)", patterns[i], TARGET_TRANS);
		if (i < (num_patterns - 1))
			*ptr++ = '|';
	}

	/* parse this expression */
	if (!(rx = rx_parse_tok(scratch, all, ptr))) {
		log_error("Couldn't parse regex");
		goto bad;
	}

	m->mem = mem;
	m->scratch = scratch;
	m->num_nodes = _count_nodes(rx);
	m->nodes = dm_pool_alloc(scratch, sizeof(*m->nodes) * m->num_nodes);

	if (!m->nodes)
		goto_bad;

	_fill_table(m, rx);
	_create_bitsets(m);
	_calc_functions(m);
	_calc_states(m, rx);
	dm_pool_destroy(scratch);
	m->scratch = NULL;

	return m;

      bad:
	dm_pool_destroy(scratch);
	dm_pool_free(mem, m);
	return NULL;
}

/*
 * Follow the transition from state 'cs' on input byte 'c'.
 *
 * Returns the next state, or NULL if 'cs' has no transition on 'c'.
 * If the next state is accepting and its 'final' number is larger than
 * *r, *r is raised to that number.  Across repeated calls, *r therefore
 * ends up holding the largest 'final' seen.
 */
static struct dfa_state *_step_matcher(int c, struct dfa_state *cs, int *r)
{
	struct dfa_state *next = cs->lookup[(unsigned char) c];

	if (next != NULL && next->final != 0 && next->final > *r)
		*r = next->final;

	return next;
}

/*
 * Run the string 's' through the DFA in 'regex'.
 *
 * The walk is HAT_CHAR, then each byte of 's', then DOLLAR_CHAR.  The
 * markers presumably stand for the '^' and '$' anchors; they are
 * defined outside this file.  The walk stops early as soon as a state
 * has no transition.
 *
 * Returns the largest 'final' number reached, minus 1, or -1 if no
 * accepting state was reached.  NOTE(review): this is presumably the
 * index of the matching pattern passed to dm_regex_create(); confirm
 * that 'final' numbering follows pattern order.
 */
int dm_regex_match(struct dm_regex *regex, const char *s)
{
	int best = 0;
	struct dfa_state *cs = _step_matcher(HAT_CHAR, regex->start, &best);

	while (cs && *s)
		cs = _step_matcher(*s++, cs, &best);

	if (cs)
		_step_matcher(DOLLAR_CHAR, cs, &best);

	/* 'final' numbers start at 1; shift to a zero-based result */
	return best - 1;
}
