/* -*-	Mode:C++; c-basic-offset:2; tab-width:2; indent-tabs-mode:t -*- */
#include <stack>
#include <vector>
#include <iostream>
#include <string>

#include "nfFirewall.h"
#include "config.h"


using namespace std;

extern int parseIPT(nfFirewall *f, const char *s);
/**
 *  Construct the firewall model and run the whole analysis pipeline.
 *
 *  @param s       rule text, or (when isFile is true) the argument
 *                 handed to readFile() to obtain the rule text
 *  @param isFile  selects how s is interpreted
 */
nfFirewall::nfFirewall(string s, bool isFile)
{
	// Start from s itself; replace it by readFile()'s result on request.
	string data = s;
	if (isFile) {
		data = readFile(s);
	}

	// defaultChains() runs before the parser is given the rule text.
	defaultChains();

	// NOTE(review): parseIPT returns an int that is discarded here --
	// confirm whether it reports parse failures.
	parseIPT(this, data.c_str());

	// Analysis passes, in this order.
	initInputSpace();
	dependencyGraph();
	contextAnalysis();
}


/**
 *  Built-in chains are exposed, so their input space is ALL.
 *
 *  User chains must get their input (in)directly from built-in
 *  chains.  Their input space starts from None and accumulates as we
 *  determine the transfer functions.
 */

void nfFirewall::initInputSpace()
{
	map<string, ACL>::iterator pos;
	for (pos = acls.begin();pos != acls.end();++pos){
		if(pos->first == "INPUT" || 
			 pos->first == "FORWARD" || 
			 pos->first == "OUTPUT" || 
			 pos->first == "PREROUTING" ||
			 pos->first == "POSTROUTING") {
			pos->second.InputSpace = bddtrue;
		}
		else{
			pos->second.InputSpace = bddfalse;
		}
	}
}

/**
 *  Look up a rule from a "<chain>:<index>" reference.
 *
 *  @param s  chain name, a ':', and the 1-based position of the rule
 *            inside that chain's rule list
 *  @return   pointer to the matching rule held in acls, or NULL when
 *            s contains no ':', the index is not a positive number,
 *            the chain is not in acls, or the chain has fewer than
 *            <index> rules
 */
Rule * nfFirewall::getRule(string s)
{
	const size_t colon = s.find_first_of(":");
	if (colon == string::npos) {
		// No separator means no index.  Previously npos + 1 wrapped to 0
		// and the whole string was re-parsed as the index by accident.
		return NULL;
	}

	const string name = s.substr(0, colon);
	const int cnt = atoi(s.substr(colon + 1).c_str());
	if (cnt < 1) {
		// Positions are counted from 1, so this can never match.
		return NULL;
	}

	map<string, ACL>::iterator it = acls.find(name);
	if (it == acls.end()) {
		return NULL;
	}

	// Walk the chain's rules, counting from 1, until position cnt.
	int i = 1;
	for (list<Rule>::iterator pos = it->second.rules.begin();
			 pos != it->second.rules.end(); ++pos, ++i) {
		if (i == cnt) {
			return &(*pos);
		}
	}
	return NULL;
}

void nfFirewall::dependencyGraph()
{
	using namespace boost;
  Graph g;

  dynamic_properties dp;
  dp.property("id", get(vertex_name, g));
  dp.property("weight", get(edge_weight, g));

  property_map < Graph, vertex_name_t >::type name_map = get(vertex_name, g);
  property_map < Graph, edge_weight_t >::type weight_map = get(edge_weight, g);
  
	std::map<string, ACL>::iterator pos;
	string cname;


  for(pos = acls.begin();pos!=acls.end();++pos){
		cname=  pos->first;
		Vertex tmp = add_vertex(g);
		put(name_map, tmp, cname);
  } 

  for(pos = acls.begin();pos!=acls.end();++pos){
		set<string> deps = pos->second.dependency();
		Vertex a = boost::getVertexbyName(pos->first, g);
		set<string>::iterator it;
		for(it = deps.begin();it != deps.end(); ++it){
			Vertex b = getVertexbyName(*it, g);
			add_edge(a, b, g);
		}
  } 
  std::ofstream  dot("dependency.dot");
  write_graphviz(dot, g, dp, std::string("id") );
}


/**
 *  In IPTables, one can jump to a user chain and return from there,
 *  which is like a function call.  Recent versions of iptables forbid
 *  looped calls, so we can safely assume the call graph is a DAG.
 *
 *  This function tries to determine the context of each chain.  For a
 *  chain being called, we need to determine the InputSet.  For a
 *  calling chain, we need to determine what could be returned.  We
 *  start from the bottom of the dependency graph and determine its
 *  transfer function, i.e. assuming the input is the entire set, what
 *  is the <accept, drop, log, return>.  Recursively, we climb up the
 *  call graph and determine the transfer function of every chain.
 *  Then, from the program entry points (INPUT, OUTPUT, FORWARD), we
 *  determine the input to each called chain.
 */

void nfFirewall::contextAnalysis()
{
	// Entry chains, processed in exactly this order.  Each call computes
	// the transfer function of the named chain; chains it jumps to are
	// handled from within StateUpdate().
	static const char * const roots[] = { "INPUT", "OUTPUT", "FORWARD" };
	static const size_t nRoots = sizeof(roots) / sizeof(roots[0]);

	for (size_t k = 0; k < nRoots; ++k) {
		acls[roots[k]].transferFunction(*this);
	}
}

/**
 *  Fold a single rule into the running state of a chain.
 *
 *  @param state  map from outcome name to BDD; this function reads and
 *                writes the entries "remain", "accept", "deny", "log"
 *                and the entry named after a terminal target
 *  @param r      the rule to apply; its target must not be "remain"
 *
 *  A terminal rule adds its predicate to state[r.target] and removes
 *  it from state["remain"].  Any other rule is treated as a jump to
 *  the chain acls[r.target]: that chain's transfer function is
 *  computed, its input space is widened by what this rule sends to it,
 *  and its accept/deny/log/return results for that input are merged
 *  into state.
 */
void nfFirewall::StateUpdate(ACLState & state, Rule r)
{
	assert (r.target != "remain");

	if (!r.terminal()) {
		// Make sure the called chain's transfer function is available
		// before reading transFunc below.
		ACL & callee = acls[r.target];
		callee.transferFunction(*this);

		// What this rule hands to the callee: still-remaining traffic
		// that matches the rule's predicate.
		bdd input = bdd_and(state["remain"], r.pred.toBDD());

		// input space is updated here
		callee.unionInputSpace(input);

		// The callee's verdicts, restricted to that input.
		bdd returned = bdd_and(input, callee.transFunc["return"]);
		bdd accepted = bdd_and(input, callee.transFunc["accept"]);
		bdd denied   = bdd_and(input, callee.transFunc["deny"]);
		bdd logged   = bdd_and(input, callee.transFunc["log"]);

		// remain = unmatched + returned by target
		state["remain"] = bdd_or(bdd_and(state["remain"], bdd_not(r.pred.toBDD())),
														 returned);
		// accept, deny and log each grow by the callee's share
		state["accept"] = bdd_or(state["accept"], accepted);
		state["deny"] = bdd_or(state["deny"], denied);
		state["log"] = bdd_or(state["log"], logged);
		return;
	}

	// Terminal target: create the entry on first use, then accumulate.
	if (state.find(r.target) == state.end()) {
		state[r.target] = bddfalse;
	}
	state[r.target] = bdd_or(state[r.target], r.pred.toBDD());

	// Whatever the rule matched no longer remains for later rules.
	state["remain"] = bdd_and(state["remain"], bdd_not(r.pred.toBDD()));
}



/**
 *  Check one chain against the blacklist policy and report on stdout.
 *
 *  The blacklist is loaded from "./Policy/blacklist.conf" through
 *  getBlackList(), which fills a scratch ACL and returns a BDD.  The
 *  chain conforms when that BDD has no overlap with the chain's
 *  "accept" transfer function.
 *
 *  @param acl  chain to check; its transfer function is computed first
 *              if acl.knownTF is false
 *
 *  With DEBUG_TRACE defined, a violation prints the offending accept
 *  rules and the blacklist rules they collide with.  Without it, a
 *  one-line verdict is printed, so a violation is never silent.
 */
void nfFirewall::policyCheck(ACL& acl)
{
	ACL blacklist = ACL();
	bdd blocked = getBlackList("./Policy/blacklist.conf", blacklist);

	if (! acl.knownTF){
		acl.transferFunction(*this);
	}

	cout << "\n-----------checking--- "<< acl.name <<
			" --for policy violation-------------" << endl;

	if (bdd_and(blocked, acl.transFunc["accept"]) == bddfalse){
		cout << "good! conform to blacklist "<< endl;
		return;
	}

#ifdef DEBUG_TRACE
	cout << "\n\t\t----------------Rules violating policy-------------" << endl;

	list<Rule>::iterator pos;

	// NOTE(review): i counts accept rules only, so the printed number
	// is the rule's rank among accept rules, not its position in the
	// chain.
	int i=1;
	cout << "\n\t\t----------------------------------------------------";
	for (pos = acl.rules.begin(); pos != acl.rules.end(); ++pos){
		if (pos->target=="accept"){
			/**
			 * this may produce superfluous traces, we should check the
			 * incremental accept instead
			 */
			if (bdd_and(pos->toBDD(), blocked)!=bddfalse){
				cout << "\n\t\t " <<  i << " "<<pos->toString();
			}
			++i;
		}
	}
	cout << "\n\t\t---------------------------------------------------" << endl;

	cout << "\n\t\t----------------Policy Violated-------------------" << endl;

	// Here i is the position of the rule inside the blacklist.
	i=1;
	cout << "\n\t\t----------------------------------------------------";
	for (pos = blacklist.rules.begin(); pos != blacklist.rules.end(); ++pos){
		if (bdd_and(pos->toBDD(),acl.transFunc["accept"] )!=bddfalse){
			cout << "\n\t\t " <<  i << " "<<pos->toString();
		}
		++i;
	}
	cout << "\n\t\t----------------------------------------------------" << endl;
#else
	// Previously this branch printed nothing when DEBUG_TRACE was off,
	// so a violated policy produced no verdict at all.
	cout << "bad! violates blacklist "<< endl;
#endif
}
