#include "Action_RunningAvg.h"
#include "CpptrajStdio.h"

/** CONSTRUCTOR
  * All window bookkeeping starts at zero; the real values are assigned in
  * Init() (window size, threshold, divisor, counters) and Setup() (atom count).
  */
Action_RunningAvg::Action_RunningAvg()
  : Nwindow_(0), d_Nwindow_(0), frameThreshold_(0), currentWindow_(0), windowNatom_(0)
{}

/** Print keyword usage for the running-average action. The default window
  * size and its lower bound mirror what Init() uses and enforces.
  */
void Action_RunningAvg::Help() const {
  mprintf("\t[window <value>]\n"
          "  Calculate the running average of coordinates over windows of specified size.\n"
          "  <value> must be >= 1 (default 5).\n");
}

// Action_RunningAvg::Init()
/** Read the 'window' keyword (default 5) and prepare the window storage.
  * \param actionArgs Action arguments; 'window <value>' is consumed here.
  * \param init       Only consulted in MPI builds, to warn about parallel runs.
  * \return Action::ERR if the window size is < 1, Action::OK otherwise.
  */
Action::RetType Action_RunningAvg::Init(ArgList& actionArgs, ActionInit& init, int debugIn)
{
  Nwindow_ = actionArgs.getKeyInt("window",5);
  if (Nwindow_ < 1) {
    mprinterr("Error: RunningAvg: window must be >= 1.\n");
    return Action::ERR;
  }

  // Nothing stored yet; the atom count is learned on the first Setup().
  currentWindow_ = 0;
  windowNatom_ = 0;
  // One Frame per window slot. This resizes (not merely reserves), so every
  // slot exists and can be indexed immediately.
  Window_.resize( Nwindow_ );
  // Coordinate output begins once TrajoutNum() reaches this value.
  frameThreshold_ = Nwindow_ - 1;
  // Floating-point copy of the window size, used as the averaging divisor.
  d_Nwindow_ = static_cast<double>( Nwindow_ );

  mprintf("    RUNNINGAVG: Running average of size %i will be performed over input coords.\n",
          Nwindow_);
# ifdef MPI
  if (init.TrajComm().Size() > 1)
    mprintf("\nWarning: 'runavg' in parallel will not work correctly if coordinates have\n"
              "Warning:   been modified by previous actions (e.g. 'rms').\n"
              "Warning: In addition, certain output trajectory formats may not write correctly\n\n");
# endif
  return Action::OK;
}

// Action_RunningAvg::Setup()
/** Allocate the window, sum, and result frames the first time a topology
  * with a nonzero atom count is seen. windowNatom_ == 0 marks "not yet set
  * up". A later topology whose atom count differs from windowNatom_ yields
  * Action::SKIP, because its frames cannot be mixed into the stored window.
  */
Action::RetType Action_RunningAvg::Setup(ActionSetup& setup) {
  int const topNatom = setup.Top().Natom();
  bool const sizeDiffers = (topNatom != windowNatom_);

  // Window already sized for a different atom count: refuse this topology.
  if (sizeDiffers && windowNatom_ != 0) {
    mprintf("Warning: # atoms in topology %s different than previous topology.\n",
            setup.Top().c_str());
    mprintf("Warning:   Running average will NOT be carried over between topologies!\n");
    return Action::SKIP;
  }

  if (sizeDiffers) {
    // First setup: size every window slot (no masses), the running-sum frame
    // (starting from zero), and the frame that receives the average.
    windowNatom_ = topNatom;
    for (int slot = 0; slot < Nwindow_; ++slot)
      Window_[slot].SetupFrame( windowNatom_ );
    avgFrame_.SetupFrame( windowNatom_ );
    avgFrame_.ZeroCoords();
    resultFrame_.SetupFrame( windowNatom_ );
  }

  mprintf("\tRunning average set up for %i atoms.\n",windowNatom_);
  return Action::OK;
}

#ifdef MPI
/** Prime the running-average window from preloaded frames. Presumably these
  * are the frames immediately preceding this process's first frame (TODO:
  * confirm against the caller).
  * Only the last (Nwindow_ - 1) entries of preload_frames are used. If fewer
  * are supplied, all of them are used instead of indexing before the start of
  * the array. Each used frame is added to avgFrame_ and stored in Window_.
  * The slot at currentWindow_ afterwards is zeroed, so the first subtraction
  * DoAction() performs from that slot removes nothing, even when no frames
  * were copied.
  * \return 0 always.
  */
int Action_RunningAvg::ParallelPreloadFrames(FArray const& preload_frames) {
  int const nPreload = static_cast<int>( preload_frames.size() );
  // Use at most the last Nwindow_-1 frames; never start before index 0.
  int start_idx = nPreload - Nwindow_ + 1;
  if (start_idx < 0) start_idx = 0;
  for (int idx = start_idx; idx < nPreload; idx++) {
    avgFrame_ += preload_frames[idx];
    Window_[currentWindow_++] = preload_frames[idx];
  }
  // This is the slot DoAction() will subtract next once past the threshold;
  // make it contribute nothing.
  Window_[currentWindow_].ZeroCoords();
  return 0;
}
#endif

// Action_RunningAvg::DoAction()
/** Maintain the running sum over a ring buffer of Nwindow_ stored frames.
  * currentWindow_ points at the oldest stored frame. Once TrajoutNum() has
  * passed frameThreshold_, that frame is removed from the sum before the
  * incoming frame is added and stored in its place.
  * \return Action::SUPPRESS_COORD_OUTPUT while TrajoutNum() < frameThreshold_,
  *         otherwise Action::MODIFY_COORDS with the frame replaced by the
  *         window average.
  */
Action::RetType Action_RunningAvg::DoAction(int frameNum, ActionFrame& frm) {
  int const outNum = frm.TrajoutNum();

  // Window already full: drop the oldest stored frame from the running sum.
  if (outNum > frameThreshold_)
    avgFrame_ -= Window_[currentWindow_];

  // Fold the new coordinates into the sum and overwrite the oldest slot.
  avgFrame_ += frm.Frm();
  Window_[currentWindow_] = frm.Frm();

  // Advance the ring-buffer index, wrapping back to slot 0 at the end.
  if (++currentWindow_ == Nwindow_)
    currentWindow_ = 0;

  // Too few frames collected so far; hold back coordinate output.
  if (outNum < frameThreshold_)
    return Action::SUPPRESS_COORD_OUTPUT;

  // Average = running sum / window size, handed back in place of the input.
  resultFrame_.Divide( avgFrame_, d_Nwindow_ );
  frm.SetFrame( &resultFrame_ );
  return Action::MODIFY_COORDS;
}
