//  -*-	C++ -*-
//
// Methods for Markov Decision Problems
//
// The MDP class defines a set of methods for manipulating Markov
// Decision Problems including value iteration,	policy iteration, and
// prioritized sweeping	using tabular methods.
//
// All of these	classes	are parameterized using	templates in terms of
// two types:  State and Action.  Values of these types	are passed to
// methods of the Problem class	in order to manipulate them.

#ifndef	MDP_H
#define	MDP_H

#include "vector.h"
#include "list.h"
#include "priorque.h"
#include "clrandom.h"
#include "minmax.h"

// Global switches and counters shared with the rest of the program.
// They are only declared here; the definitions live elsewhere.

// Non-zero enables the tracing output written to cout by the MDP
// algorithms (e.g. BellmanBackup and ValueIteration).
extern int VERBOSE;
// count number of "Q" backups.  Incremented once per updated Q entry;
// also used as a logical clock for priority-queue entries.
extern int BackupCounter;
// count number of PQ items discarded for being old
extern int OldPopCounter;
// perform MonteCarlo evaluation every MCinterval Q updates.
extern int MCinterval;
// Number of steps to perform MC
extern int MCNSteps;

// Selects the policy simulated by MDP::MonteCarloEval: one of the
// P* constants below.
extern int MCpolicy;
const int PGreedy = 0;
const int PRandom = 1;

// count primitive actions with environment
extern int ActionCounter;

// One stochastic outcome of an action: with the given probability the
// transition yields `reward` and leads to `state`.  The backups in MDP
// combine these as probability * (reward + discount * V(state)).
template <class State>
class ResultInfo
{
public:
  // Default constructor, used when a record is created first and filled
  // in afterwards (e.g. by operator>>).  Every member is
  // value-initialized so a record never holds indeterminate data.
  ResultInfo(): probability(0.0), reward(0.0), state() {}

  // Build a fully specified outcome.
  ResultInfo(float p, float r, State s): probability(p),
					 reward(r),
					 state(s) {}

  float probability;   // chance of this outcome
  float reward;        // reward received on this outcome
  State state;         // resulting state
};

// Stream operators for ResultInfo, declared here and defined with the
// other IO routines further down.  The textual form is
// "(probability reward state)".
template <class	State>
istream	& operator >> (istream &, ResultInfo<State> &);
template <class	State>
ostream	& operator << (ostream &, ResultInfo<State> &);


// Everything known about one action in one state: the action itself,
// the list of its possible outcomes, and the value of BackupCounter at
// the time its Q value was last backed up.
template <class State, class Action>
class SuccessorInfo
{
public:
  // constructors

  // The default constructor is the one used when records are allocated
  // and then read from a stream.  It must set the timer: PSweep
  // compares it against queue timestamps, and an uninitialized value
  // would make that comparison meaningless.  -1 means "never backed
  // up", matching the other constructor.
  SuccessorInfo(): action(), timer(-1) {}
  SuccessorInfo(Action a): action(a), timer(-1) {}

  Action action;
  int timer;	  // every time we do a Q backup here, we set this timer
  list<ResultInfo<State> *> resultStates;   // possible outcomes of action
};

// Stream operators for SuccessorInfo, declared here and defined with
// the other IO routines further down.  The textual form is
// "(action resultStates)".
template <class	State, class Action>
istream	& operator >> (istream &, SuccessorInfo<State, Action> &);
template <class	State, class Action>
ostream	& operator << (ostream &, SuccessorInfo<State, Action> &);

// A (state, action) pair that can lead into some other state, together
// with the probability of that transition.  Prioritized sweeping walks
// these records to decide which Q entries to revisit.
template <class State, class Action>
class PredecessorInfo
{
public:
  State   state;        // the predecessor state
  Action  action;       // the action taken in that state
  float   probability;  // chance that the action leads to the successor

  // The probability is optional and defaults to zero.
  PredecessorInfo(State s, Action a, float prob = 0.0)
    : state(s), action(a), probability(prob)
  {
  }
};


//////////////////////////////////////////////////////////////////////
//
// class Problem
//
// This	is a generic MDP interface.  We	have made some odd modeling
// choices.  For example, normally, one	might make functions such as
// printState, applicable, terminated, and execute be member functions
// of a	State class.  We have instead made them	member functions of
// the Problem.	 The reason for	this is	to keep	the data structures
// for states and actions very simple (even integers), and to put
// messy things, like random number generators,	into the problem.
//

// Abstract interface to one Markov Decision Problem.  The MDP class
// talks to the environment and its model exclusively through these
// virtual functions.
template <class State, class Action>
class Problem
{
public:
  // Problems are handled through base-class pointers (MDP stores a
  // Problem *), so the destructor has to be virtual for deletion
  // through such a pointer to be well defined.
  virtual ~Problem() {}

  // Discount applied to the value of successor states in the backups.
  virtual float	discountFactor() = 0;

  // Table dimensions: MDP allocates maxState() * maxAction() Q entries.
  virtual unsigned int maxState() = 0;
  virtual unsigned int maxAction() = 0;

  // Placeholder values used to initialize State / Action variables.
  virtual State	nullState() = 0;
  virtual Action nullAction() =	0;

  // Action returned by MDP::randomPolicy.
  virtual Action randomAction()	= 0;

  // Iterators over the states / actions.  Ownership of the returned
  // object is not visible from this header -- TODO confirm with the
  // implementations whether callers are expected to delete it.
  virtual iterator<State> * stateIterator() = 0;
  virtual iterator<Action> * actionIterator() =	0;

  // Model access.  successors(s) lists the actions of s with their
  // outcomes; MDP asserts that it is non-null.  predecessors(s) lists
  // the (state, action) pairs leading into s; MDP tolerates a null
  // result there.
  virtual list<SuccessorInfo<State, Action> *> * successors(State) = 0;
  virtual list<PredecessorInfo<State, Action> *> * predecessors(State) = 0;

  // Map a state / action to its position in the Q table.  The entry
  // for (s, a) is stateIndex(s) * maxAction() + actionIndex(a).
  virtual unsigned int stateIndex(State) = 0;
  virtual unsigned int actionIndex(Action) = 0;

  // Write a human-readable form of a state / action to str.
  virtual void printState(ostream & str, State & s) = 0;
  virtual void printAction(ostream & str, Action & a) =	0;

  // Simulation interface.  init() sets s to a start state,
  // terminated() reports whether s ends an episode, and execute()
  // applies a to s in place and returns the reward received.
  virtual void init(State & s) = 0;
  virtual int applicable(Action	& a, State & s)	= 0;
  virtual int terminated(State & s) = 0;
  virtual float	execute(Action & a, State & s) = 0;

  // Random number generator belonging to the problem.
  virtual clrandom & rng() = 0;
};

// this is the record that is placed onto the priority queue.
// It names a (state, action) pair whose Q value should be recomputed,
// the priority of that request, and the time at which it was queued.
template <class State, class Action>
class PQInfo
{
public:
  State  *   state;   // state whose Q entry should be backed up
  Action *   action;  // action whose Q entry should be backed up
  float	     delta;   // priority of the request
  int	     timer;   // we use the backup counter as a timer

  // constructors
  PQInfo(State * s, Action * a, float d, int t):
    state(s),
    action(a),
    delta(d),
    timer(t) {}

  // used for empty priority queue.
  PQInfo(): state(0), action(0), delta(0.0), timer(0) {}

  // The queue removes its minimum element, so we make it a maximizing
  // priority queue by reversing the comparison: a record is "smaller"
  // when its delta is larger.  The operator is const-correct so that
  // const records and temporaries can be compared as well.
  int operator < (const PQInfo<State, Action> & other) const {
    return delta > other.delta;
  }
};

//////////////////////////////////////////////////////////////////////
//
// class MDP
//
// We have made	this a separate	class from a Problem.  This is a bit
// odd as well.	 This class supports all of the	MDP algorithms and
// also	contains the necessary data structures (e.g., the Q function
// and the priority queue).
//

// Tabular solver for a Problem.  It owns the Q table and the priority
// queue and implements the planning algorithms on top of them.
template <class State, class Action>
class MDP
{
public:
  MDP(Problem<State, Action> * p);

  // ---- planning algorithms ----
  void ValueIteration(float epsilon);
  void PrioritizedSweeping(int n, float epsilon);
  void PSweep(int n, float epsilon);
  void PushPredecessors(State s, float delta, float epsilon);

  // ---- backups ----
  float BellmanBackup(State &);
  float Qbackup(State, Action);

  // ---- value function access ----
  float & Qfunction(State, Action);
  float Vfunction(State);
  void PrintValueFunction(ostream &);

  // ---- policies and their evaluation ----
  Action greedyPolicy(State);
  Action randomPolicy(State);
  Action explorationPolicy(State &);
  void MonteCarloEval(int n);

protected:
  // Q table, one row of nActions entries per state.
  vector<float> Q;
  // The problem being solved.
  Problem<State, Action> * problem;
  // Maximum number of actions; cached for fast indexing into Q.
  unsigned int nActions;
  // Pending backups for prioritized sweeping.
  heap<PQInfo<State, Action> > PQ;
  // Eligibility vector for TD(lambda)-style methods.
  vector<float> eligibility;
  // Probability that explorationPolicy picks a random action.
  float beta;
  // Amount subtracted from beta on each explorationPolicy call.
  float betaDecay;
};

//////////////////////////////////////////////////////////////////////
//
// class implementations
//

// Build a solver for problem p.  The Q table is created with
// maxState() * maxAction() entries, all zero, and the priority queue is
// given the same capacity.
//
// The initializers are listed in member declaration order, which is
// the order in which they actually run.  beta and betaDecay are set to
// zero so that explorationPolicy never reads indeterminate values; with
// these defaults it behaves exactly like greedyPolicy until a derived
// class assigns something else.
template <class State, class Action>
MDP<State, Action>::MDP(Problem<State, Action> * p):
  Q(p->maxState() * p->maxAction(), 0.0),
  problem(p),
  nActions(p->maxAction()),
  PQ(p->maxState() * p->maxAction()),
  beta(0.0),
  betaDecay(0.0)
{
}

// Value of a state: the largest Q value over the actions listed by
// problem->successors(state).  Terminated states, and states with no
// listed actions, have value 0.
template <class State, class Action>
float MDP<State, Action>::Vfunction(State state)
{
  if (problem->terminated(state)) {
    return 0.0;
  }

  list<SuccessorInfo<State, Action> *> * succs = problem->successors(state);
  assert(succs);
  listIterator<SuccessorInfo<State, Action> *> cursor(*succs);

  float best = 0.0;
  int haveValue = 0;
  // start of this state's row in the Q table
  const unsigned int rowStart = nActions * problem->stateIndex(state);

  for (cursor.init(); !cursor; ++cursor) {
    const unsigned int slot = problem->actionIndex(cursor()->action) + rowStart;
    const float candidate = Q[slot];
    if (!haveValue || candidate > best) {
      best = candidate;
      haveValue = 1;
    }
  }
  return best;
}

// Reference to the Q table entry for (state, action).  Rows are
// indexed by state and columns by action, so the result can be both
// read and assigned.
template <class State, class Action>
float & MDP<State, Action>::Qfunction(State state, Action action)
{
  const unsigned int row = problem->stateIndex(state);
  const unsigned int column = problem->actionIndex(action);
  return Q[row * nActions + column];
}

// Recompute Q(state, action) from the model:
//
//   Q(s,a) = sum over outcomes of  p * (r + discount * V(s'))
//
// The new value is stored in the Q table, BackupCounter is advanced,
// a Monte Carlo evaluation is run when the counter reaches a multiple
// of MCinterval, and the action's timer is set to the counter.
//
// Returns the new Q value.  If the action is not among the successors
// of the state, no backup is performed and the value currently stored
// in the table is returned (the function used to fall off its end
// here, leaving the result undefined).
template <class State, class Action>
float MDP<State, Action>::Qbackup(State state, Action action)
{
  list<SuccessorInfo<State, Action> *> * succs = problem->successors(state);
  assert(succs);

  // find the desired action in the successor information
  listIterator<SuccessorInfo<State, Action> *> itr(*succs);
  for (itr.init(); !itr; ++itr) {
    if (itr()->action == action) {
      float answer = 0.0;

      // loop through successor states and compute Bellman formula
      listIterator<ResultInfo<State> *> ritr(itr()->resultStates);
      for (ritr.init(); !ritr; ++ritr) {
	answer += ritr()->probability * (ritr()->reward +
					 problem->discountFactor() *
					 Vfunction(ritr()->state));
      }
      Qfunction(state, action) = answer;
      BackupCounter++;
      // MCinterval == 0 would be a division by zero; treat it as
      // "no periodic evaluation".
      if (MCinterval != 0 && BackupCounter % MCinterval == 0) {
	MonteCarloEval(MCNSteps);
      }
      itr()->timer = BackupCounter;
      return answer;
    }
  }

  // The action is not available in this state: nothing to back up.
  return Qfunction(state, action);
}

//
// Perform a bellman backup
//
// Every action listed for the state gets its Q value recomputed from
// the model and written back to the table immediately, so later
// actions in the same call already see the earlier updates.
// BackupCounter is advanced once per action, and a Monte Carlo
// evaluation is run whenever it reaches a multiple of MCinterval.
//
// Returns the bellman error: |max old Q - max new Q| over the actions,
// or 0.0 if the state has no actions at all.
//
template <class State, class Action>
float MDP<State, Action>::BellmanBackup(State & state)
{
  if (VERBOSE) {
    cout << "Starting BellmanBackup for	state ";
    problem->printState(cout, state);
    cout << endl;
  }

  // we update the Q value for each action
  list<SuccessorInfo<State, Action> *> * succs = problem->successors(state);
  assert(succs);
  float oldValue = 0.0;
  float newValue = 0.0;
  int firstTime = 1;

  listIterator<SuccessorInfo<State, Action> *> itr(*succs);
  for (itr.init(); !itr; ++itr) {
    if (VERBOSE) cout << "  Action = " << itr()->action << endl;
    float qvalue = 0.0;

    // loop through successor states and compute Bellman formula
    listIterator<ResultInfo<State> *> ritr(itr()->resultStates);
    for (ritr.init(); !ritr; ++ritr) {
      // evaluate V(state') once; it is needed for the trace and the sum
      const float successorValue = Vfunction(ritr()->state);
      if (VERBOSE) {
	cout << "    prob=" << ritr()->probability
	     << " reward=" << ritr()->reward
	     << " state'=" << ritr()->state
	     << " V(state') = " << successorValue
	     << endl;
      }
      qvalue += ritr()->probability * (ritr()->reward +
				       problem->discountFactor() *
				       successorValue);
    }

    // compute new and old values
    float & cell = Qfunction(state, itr()->action);
    if (VERBOSE) {
      cout << "	 Qold =	" << cell << " Qnew = " << qvalue << endl;
    }

    if (firstTime) {
      oldValue = cell;
      newValue = qvalue;
      firstTime = 0;
    }
    else {
      if (cell > oldValue) oldValue = cell;
      if (qvalue > newValue) newValue = qvalue;
    }
    // this updates the Qvalue
    cell = qvalue;
    BackupCounter++;
    // MCinterval == 0 would be a division by zero; treat it as
    // "no periodic evaluation".
    if (MCinterval != 0 && BackupCounter % MCinterval == 0) {
      MonteCarloEval(MCNSteps);
    }
  }
  // note: if there were no actions at all, this will return 0.0 for
  // the Bellman error.
  const float bellmanError = fabs(oldValue - newValue);
  if (VERBOSE) {
    cout << "Bellman error=" << bellmanError << endl;
  }
  return bellmanError;
}

// Value iteration: sweep over all states doing a Bellman backup on
// each, and repeat until a whole sweep finishes with no state's
// Bellman error above epsilon.  A progress line is written to cout
// after every sweep; with VERBOSE set the value function is printed
// as well.
template <class State, class Action>
void MDP<State, Action>::ValueIteration(float epsilon)
{
  iterator<State> & states = *(problem->stateIterator());

  int sweep = 0;
  for (;;) {
    ++sweep;
    int notConverged = 0;
    float errorSum = 0.0;

    for (states.init(); !states; ++states) {
      const float err = BellmanBackup(states());
      if (err > epsilon) {
	notConverged = 1;
      }
      errorSum += err;
    }

    cout << "End of iteration " << sweep
	 << " BackupCounter = " << BackupCounter
	 << " total bellman error = " << errorSum << endl;
    if (VERBOSE) {
      PrintValueFunction(cout);
    }

    if (!notConverged) {
      break;
    }
  }
}

// Queue a backup request for every predecessor of s whose Q value
// could change noticeably now that the value of s has moved by delta.
//
// A predecessor reaching s with probability p is affected by about
// p * delta.  It is queued only if that amount exceeds epsilon, and
// the same amount is used as its priority.  (The priority used to be
// p alone, so the size of the change played no part in the ordering
// even though it decided admission to the queue.)
//
// The state and action are copied onto the heap; whoever pops the
// record (PSweep) is responsible for deleting them.
template <class State, class Action>
void MDP<State, Action>::PushPredecessors(State s, float delta, float epsilon)
{
  list<PredecessorInfo<State, Action> *> * preds = problem->predecessors(s);
  if (!preds) return;

  listIterator<PredecessorInfo<State, Action> *> itr(*preds);
  for (itr.init(); !itr; ++itr) {
    const float priority = itr()->probability * delta;
    if (priority > epsilon) {
      State * ns = new State(itr()->state);
      Action * na = new Action(itr()->action);
      // stamp the request with the current time so that PSweep can
      // recognize it as stale if the entry is backed up in the meantime
      PQInfo<State, Action> info(ns, na, priority, BackupCounter);
      PQ.add(info);
    }
  }
}

// Process up to n requests from the priority queue, returning early
// if the queue runs empty.
//
// For each request the matching action is located among the state's
// successors.  If that Q entry has been backed up since the request
// was queued, the request is stale: it is dropped and OldPopCounter is
// incremented.  Otherwise the entry is backed up, and if the new Q
// value exceeds the state's previous value the predecessors of the
// state are queued in turn.  The State and Action copies carried by
// the request are deleted in every case.
template <class State, class Action>
void MDP<State, Action>::PSweep(int n, float epsilon)
{
  cout << "# PSweep: n = " << n << " pq	= " << PQ.size() << endl;

  for (int step = 0; step < n; ++step) {
    // progress report every 100 steps
    if (step % 100 == 0) {
      cout << PQ.size() << " "
	   << BackupCounter << " "
	   << OldPopCounter << " size backups oldpops" << endl;
    }
    if (PQ.isEmpty()) {
      return;
    }

    PQInfo<State, Action> request = PQ.deleteMin();
    State * queuedState = request.state;
    Action * queuedAction = request.action;

    list<SuccessorInfo<State, Action> *> * succs =
      problem->successors(*queuedState);
    assert(succs);

    // look for the requested action; at most one entry is processed
    listIterator<SuccessorInfo<State, Action> *> cursor(*succs);
    for (cursor.init(); !cursor; ++cursor) {
      if (!(cursor()->action == *queuedAction)) {
	continue;
      }

      if (cursor()->timer > request.timer) {
	// backed up more recently than the request: ignore it
	OldPopCounter++;
      }
      else {
	const float valueBefore = Vfunction(*queuedState);
	const float qAfter = Qbackup(*queuedState, *queuedAction);
	if (qAfter > valueBefore) {
	  float change = fabs(qAfter - valueBefore);
	  PushPredecessors(*queuedState, change, epsilon);
	}
	// A full BellmanBackup of the state followed by
	// PushPredecessors would be the alternative here.
      }
      break;
    }

    delete queuedState;
    delete queuedAction;
  }
}

// Prioritized sweeping.
//
//   n       = number of steps of prioritized sweeping to do after each
//             bellman backup.
//   epsilon = cutoff below which changes are not propagated.
//
// Every state gets one Bellman backup; after each, its predecessors
// are queued and up to n queued requests are processed.  When all
// states have been visited the queue is processed with a very large
// step budget, and a final statistics line is written to cout.
template <class State, class Action>
void MDP<State, Action>::PrioritizedSweeping(int n, float epsilon)
{
  // Step budget for the final pass; PSweep stops as soon as the queue
  // is empty, so this only needs to be larger than any realistic
  // backlog.
  const int drainBudget = 99999999;

  iterator<State> & itr = *(problem->stateIterator());

  for (itr.init(); !itr; ++itr) {
    float delta = BellmanBackup(itr());
    PushPredecessors(itr(), delta, epsilon);
    PSweep(n, epsilon);
  }
  PSweep(drainBudget, epsilon);
  cout << PQ.size() << " "
       << BackupCounter << " "
       << OldPopCounter <<  " size backups oldpops FINAL" << endl;
}

// IO routines
// Read a ResultInfo written as "(probability reward state)".
// A missing or wrong delimiter is reported on cerr and the program is
// aborted.
template <class State>
istream & operator >> (istream & str, ResultInfo<State> & ri)
{
  // Start from a known value: if the extraction fails, delim is left
  // untouched and must not be compared while indeterminate.
  char delim = '\0';
  str >> delim;
  if (delim != '(') {
    cerr << "While reading ResultInfo, expected	'(' but	read '" << delim << "'." << endl;
    abort();
  }
  str >> ri.probability;
  str >> ri.reward;
  str >> ri.state;
  // Reset so that a failed read cannot be mistaken for the '(' above.
  delim = '\0';
  str >> delim;
  if (delim != ')') {
    cerr << "while reading ResultInfo, expected	')' but	read '" <<
      delim << "'." << endl;
    abort();
  }
  return str;
}

// Pointer overload: read into the ResultInfo that ri points at.
// Lets containers of ResultInfo pointers be filled element by element.
template <class State>
istream & operator >> (istream & str, ResultInfo<State> * ri)
{
  ResultInfo<State> & target = *ri;
  return str >> target;
}

// Write a ResultInfo as "(probability reward state)".
template <class State>
ostream & operator << (ostream & str, ResultInfo<State> & ri)
{
  str << "(";
  str << ri.probability;
  str << " ";
  str << ri.reward;
  str << "	";
  str << ri.state;
  str << ")";
  return str;
}

// Pointer overload: write the ResultInfo that ri points at.
template <class State>
ostream & operator << (ostream & str, ResultInfo<State> * ri)
{
  ResultInfo<State> & source = *ri;
  return str << source;
}

// Read a SuccessorInfo written as "(action resultStates)".
// A missing or wrong delimiter is reported on cerr and the program is
// aborted.  The closing delimiter is now treated like the opening one;
// previously a bad ')' was only reported and reading carried on with
// corrupt input.
template <class State, class Action>
istream & operator >> (istream & str, SuccessorInfo<State, Action> & si)
{
  // Start from a known value: if the extraction fails, delim is left
  // untouched and must not be compared while indeterminate.
  char delim = '\0';
  str >> delim;
  if (delim != '(') {
    cerr << "While reading SuccessorInfo expected '(' but read '"
	 << delim << "'." << endl;
    abort();
  }
  str >> si.action;
  str >> si.resultStates;
  // A record fresh from the input has never been backed up.
  si.timer = -1;
  // Reset so that a failed read cannot be mistaken for the '(' above.
  delim = '\0';
  str >> delim;
  if (delim != ')') {
    cerr << "While reading SuccessorInfo expected ')' but read '"
	 << delim << "'." << endl;
    abort();
  }
  return str;
}

// Pointer overload: read into the SuccessorInfo that si points at.
template <class State, class Action>
istream & operator >> (istream & str, SuccessorInfo<State, Action> * si)
{
  SuccessorInfo<State, Action> & target = *si;
  return str >> target;
}

// Write a SuccessorInfo as "(action resultStates)".
template <class State, class Action>
ostream & operator << (ostream & str, SuccessorInfo<State, Action> & si)
{
  str << "(";
  str << si.action;
  str << " ";
  str << si.resultStates;
  str << ")";
  return str;
}

// Pointer overload: write the SuccessorInfo that si points at.
template <class State, class Action>
ostream & operator << (ostream & str, SuccessorInfo<State, Action> * si)
{
  SuccessorInfo<State, Action> & source = *si;
  return str << source;
}

// Read a parenthesized list of SuccessorInfo records into target,
// which must be empty.  Each record is allocated with new and stored
// by pointer.  Records are added one by one and the list is reversed
// at the end so that it is in input order.
//
// Malformed or truncated input is reported on cerr and the program is
// aborted, as in the readers for the individual records.  These checks
// no longer depend on assert being enabled.
template <class State, class Action>
istream & operator >> (istream & istr,
		       list<SuccessorInfo<State, Action> *> & target)
{
  assert(target.isEmpty()); // target must be empty

  char delim = '\0';
  istr >> delim;
  if (!istr) {
    cerr << "Input ended where `(' was expected." << endl;
    abort();
  }
  if (delim != '(') {
    cerr << "Found `" << delim << "' where `(' was expected." << endl;
    abort();
  }

  do {
    istr >> delim;
    if (!istr) {
      // Without this check the stale delimiter would be put back and
      // the element reader would run on a failed stream.
      cerr << "Input ended inside a list of SuccessorInfo." << endl;
      abort();
    }
    if (delim == ')') break;
    istr.putback(delim);
    SuccessorInfo<State, Action> * val = new SuccessorInfo<State, Action>;
    istr >> val;
    target.add(val);
  } while (istr);
  target.reverse();
  return istr;
}

// Read a parenthesized list of ResultInfo records into target, which
// must be empty.  Each record is allocated with new and stored by
// pointer.  Records are added one by one and the list is reversed at
// the end so that it is in input order.
//
// Malformed or truncated input is reported on cerr and the program is
// aborted, as in the readers for the individual records.  These checks
// no longer depend on assert being enabled.
template <class State>
istream & operator >> (istream & istr,
		       list<ResultInfo<State> *> & target)
{
  assert(target.isEmpty()); // target must be empty

  char delim = '\0';
  istr >> delim;
  if (!istr) {
    cerr << "Input ended where `(' was expected." << endl;
    abort();
  }
  if (delim != '(') {
    cerr << "Found `" << delim << "' where `(' was expected." << endl;
    abort();
  }

  do {
    istr >> delim;
    if (!istr) {
      // Without this check the stale delimiter would be put back and
      // the element reader would run on a failed stream.
      cerr << "Input ended inside a list of ResultInfo." << endl;
      abort();
    }
    if (delim == ')') break;
    istr.putback(delim);
    ResultInfo<State> * val = new ResultInfo<State>;
    istr >> val;
    target.add(val);
  } while (istr);
  target.reverse();
  return istr;
}

// Write one line per state to str: the state as printed by the
// problem, a space, and the state's value.
//
// The value is written to str as well; it used to go to cout whatever
// stream the caller had passed in, which split the output in two when
// str was anything other than cout.
template <class State, class Action>
void MDP<State,Action>::PrintValueFunction(ostream & str)
{
  iterator<State> & itr = *(problem->stateIterator());

  for (itr.init(); !itr; ++itr) {
    problem->printState(str, itr());
    str << " " << Vfunction(itr()) << endl;
  }
}

// The action with the largest Q value among those listed for the
// state.  Ties go to the action listed first; if no action is listed
// the problem's null action is returned.
template <class State, class Action>
Action MDP<State,Action>::greedyPolicy(State state)
{
  list<SuccessorInfo<State, Action> *> * succs = problem->successors(state);
  assert(succs);

  Action chosen = problem->nullAction();
  float chosenValue = 0.0;
  int seenAny = 0;

  listIterator<SuccessorInfo<State, Action> *> cursor(*succs);
  for (cursor.init(); !cursor; ++cursor) {
    const float candidate = Qfunction(state, cursor()->action);
    // keep the incumbent unless this is the first action or a strict
    // improvement
    if (seenAny && !(candidate > chosenValue)) {
      continue;
    }
    chosenValue = candidate;
    chosen = cursor()->action;
    seenAny = 1;
  }
  return chosen;
}

// A random action as supplied by the problem.  The state argument is
// accepted so the signature matches the other policies, but it is not
// consulted.
template <class State, class Action>
Action MDP<State,Action>::randomPolicy(State state)
{
  Action choice = problem->randomAction();
  return choice;
}


// Simulate the policy selected by MCpolicy for n steps and write the
// accumulated reward to cout.  Whenever the state is terminated the
// simulation is restarted with problem->init and carries on within the
// same step budget.  An unknown MCpolicy value aborts the program.
//
// NOTE(review): the running total is updated as
//   total = reward + discount * total
// so the discount falls on the rewards collected earlier rather than
// on the later ones -- confirm that this is the intended measure.
template <class State, class Action>
void MDP<State,Action>::MonteCarloEval(int n)
{
  float accumulated = 0.0;
  double discount = problem->discountFactor();
  Action chosen = problem->nullAction();

  State current = problem->nullState();
  problem->init(current);

  for (int step = 0; step < n; ++step) {
    if (problem->terminated(current)) {
      problem->init(current);
    }

    if (MCpolicy == PGreedy) {
      chosen = greedyPolicy(current);
    }
    else if (MCpolicy == PRandom) {
      chosen = randomPolicy(current);
    }
    else {
      cerr << "Unknown policy requested	in MonteCarloEval" << endl;
      abort();
    }

    const float stepReward = problem->execute(chosen, current);
    accumulated = (stepReward + discount * accumulated);
  }
  cout << "MonteCarloEval: " << accumulated << endl;
}

// Exploring policy: with probability beta return a random action,
// otherwise the greedy one.  A random number is drawn first, then beta
// is lowered by betaDecay (never below zero), and the lowered beta is
// the one the draw is compared against.
template <class State, class Action>
Action MDP<State, Action>::explorationPolicy(State & state)
{
  const double draw = problem->rng().between(0.0, 1.0);
  beta = max(0.0, beta - betaDecay);
  return (draw < beta) ? randomPolicy(state) : greedyPolicy(state);
}

#endif
