/******************************************************************************
 * suspend_resume.c
 *
 * Continuously suspends and resumes the domain. Nothing else is done.
 * Adapted from libcheckpoint API.
 *
 * Copyright (c) 2011 Shriram Rajagopalan (rshriram@cs.ubc.ca).
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation;
 * version 2.1 of the License.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 *
 */

#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/select.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <time.h>
#include <unistd.h>

#include <xenctrl.h>
#include <xenguest.h>
#include <xs.h>

/* Kind of guest being checkpointed.
 *
 * The enumerators are deliberately ordered: tests of the form
 * "domtype > dt_pv" select both HVM variants, which are the ones that get
 * QEMU logdirty/pause/continue handling. Keep this order. */
typedef enum {
    dt_unknown = 0, /* not determined yet (checkpoint_init default) */
    dt_pv      = 1, /* paravirtualised guest */
    dt_hvm     = 2, /* HVM guest with no HVM callback IRQ configured */
    dt_pvhvm   = 3  /* HVM with PV drivers */
} checkpoint_domtype;

/* Per-session state for suspending/resuming one guest domain.
 * Prepare with checkpoint_init(), then checkpoint_open(). */
typedef struct {
    xc_interface *xch;       /* libxc control interface */
    xc_evtchn *xce;          /* event channel handle */
    struct xs_handle *xsh;   /* xenstore handle */
    int watching_shutdown;   /* @releaseDomain watch: 0 = not registered,
                                1 = registered (registration fire pending),
                                2 = registration fire consumed */

    unsigned int domid;      /* target guest domain id */
    checkpoint_domtype domtype;
    int fd;                  /* not referenced anywhere in this file */

    int suspend_evtchn;      /* bound suspend event channel port, or -1 */

    char *errstr;            /* last error message, or NULL */
} checkpoint_state;


/* Scratch storage for formatted error messages. Several functions snprintf
 * into it and point s->errstr at it, so every such error overwrites the
 * previous message. */
static char errbuf[256];

/* suspend event channel (used for PV guests) */
static int setup_suspend_evtchn(checkpoint_state *s);
static void release_suspend_evtchn(checkpoint_state *s);

/* @releaseDomain watch used to observe the guest shutting down/suspending */
static int setup_shutdown_watch(checkpoint_state *s);
static int check_shutdown(checkpoint_state *s);
static void release_shutdown_watch(checkpoint_state *s);

/* guest suspend mechanisms and the fd wait helper they share */
static int evtchn_suspend(checkpoint_state *s);
static int compat_suspend(checkpoint_state *s);
static int pollfd(checkpoint_state *s, int fd);

/* QEMU device-model control for HVM guests */
static int switch_qemu_logdirty(checkpoint_state *s, int enable);
static int suspend_hvm(checkpoint_state *s);
static int suspend_qemu(checkpoint_state *s);
static int resume_qemu(checkpoint_state *s);

/* Describe the last error recorded in s by a checkpoint function, or NULL if
 * none has been recorded since checkpoint_init(). The text is either a
 * string literal or the module's shared errbuf -- never free it. */
char* checkpoint_error(checkpoint_state* s)
{
    char *message = s->errstr;

    return message;
}

/* Put s into an idle state: no handles open, no watch or event channel
 * held, no target domain and no recorded error. Call before
 * checkpoint_open(). */
void checkpoint_init(checkpoint_state* s)
{
    /* nothing has failed yet */
    s->errstr = NULL;

    /* target unknown until checkpoint_open() */
    s->domid = 0;
    s->domtype = dt_unknown;
    s->suspend_evtchn = -1;

    /* no resources held */
    s->watching_shutdown = 0;
    s->xsh = NULL;
    s->xce = NULL;
    s->xch = NULL;
}
void checkpoint_close(checkpoint_state* s);

/* Release everything checkpoint_open() has acquired so far, then record
 * errstr as the failure reason. The caller evaluates errstr before this
 * runs, and it is re-stored afterwards, because checkpoint_close() may
 * overwrite s->errstr. Always returns -1. */
static int checkpoint_open_fail(checkpoint_state* s, char* errstr)
{
    checkpoint_close(s);
    s->errstr = errstr;

    return -1;
}

/*
 * Open a checkpoint session to guest domid.
 *
 * Steps:
 * - open the libxc, xenstore and event channel handles;
 * - classify the guest as PV, HVM or HVM with PV drivers;
 * - register the @releaseDomain watch;
 * - bind the suspend event channel for PV guests, falling back to xenstore
 *   signalling with a warning;
 * - enable QEMU logdirty mode for HVM guests.
 *
 * s must have been prepared with checkpoint_init(). Returns 0 on success.
 * On failure it returns -1 with every acquired resource released and
 * s->errstr describing the problem (when the failing step recorded one).
 */
int checkpoint_open(checkpoint_state* s, unsigned int domid)
{
    xc_dominfo_t dominfo;
    unsigned long pvirq;

    s->domid = domid;

    s->xch = xc_interface_open(0, 0, 0);
    if (!s->xch) {
        s->errstr = "could not open control interface (are you root?)";
        return -1;
    }

    s->xsh = xs_daemon_open();
    if (!s->xsh)
        return checkpoint_open_fail(s, "could not open xenstore handle");

    s->xce = xc_evtchn_open(NULL, 0);
    if (s->xce == NULL)
        return checkpoint_open_fail(s, "could not open event channel handle");

    /* xc_domain_getinfo() enumerates domains starting at domid, so for a
     * domid that does not exist it can hand back a different domain.
     * Verify the id, as check_shutdown() does. */
    if (xc_domain_getinfo(s->xch, s->domid, 1, &dominfo) != 1
        || dominfo.domid != s->domid)
        return checkpoint_open_fail(s, "could not get domain info");

    if (dominfo.hvm) {
        if (xc_get_hvm_param(s->xch, s->domid, HVM_PARAM_CALLBACK_IRQ, &pvirq))
            return checkpoint_open_fail(s, "could not get HVM callback IRQ");
        /* a configured callback IRQ is taken to mean PV drivers are present */
        s->domtype = pvirq ? dt_pvhvm : dt_hvm;
    } else {
        s->domtype = dt_pv;
    }

    if (setup_shutdown_watch(s) < 0)
        return checkpoint_open_fail(s, s->errstr);

    /* best effort: without the event channel, suspend goes via xenstore */
    if (s->domtype == dt_pv && setup_suspend_evtchn(s) < 0)
        fprintf(stderr, "WARNING: suspend event channel unavailable, "
                "falling back to slow xenstore signalling\n");

    /* previously this path returned without releasing handles or the watch */
    if (s->domtype > dt_pv && switch_qemu_logdirty(s, 1))
        return checkpoint_open_fail(s, s->errstr);

    return 0;
}

/*
 * Tear down a checkpoint session.
 *
 * For HVM guests, QEMU logdirty mode is switched back off. The
 * @releaseDomain watch and the suspend event channel are released, and the
 * libxc, event channel and xenstore handles are closed.
 *
 * It is safe to call this on a partially opened session, and to call it
 * more than once.
 */
void checkpoint_close(checkpoint_state* s)
{
  /* Talking to QEMU needs xenstore. Without the xsh check, a second close
   * (xsh already NULL) would hand a NULL handle to switch_qemu_logdirty(). */
  if (s->xsh && s->domtype > dt_pv)
    switch_qemu_logdirty(s, 0);

  /* both releases still use the handles that are closed below */
  release_shutdown_watch(s);
  release_suspend_evtchn(s);

  if (s->xch) {
    xc_interface_close(s->xch);
    s->xch = NULL;
  }
  if (s->xce != NULL) {
    xc_evtchn_close(s->xce);
    s->xce = NULL;
  }
  if (s->xsh) {
    xs_daemon_close(s->xsh);
    s->xsh = NULL;
  }

  s->domid = 0;
  s->domtype = dt_unknown;
}

/* Suspend the guest using the best mechanism available:
 * - the bound suspend event channel, if there is one;
 * - the shutdown hypercall for plain HVM guests;
 * - otherwise a xenstore "suspend" request.
 * Returns 1 on success and 0 on failure (note: not the usual 0/-1). */
int checkpoint_suspend(checkpoint_state* s)
{
  int rc;

  if (s->suspend_evtchn >= 0) {
      rc = evtchn_suspend(s);
  } else if (s->domtype == dt_hvm) {
      rc = suspend_hvm(s);
  } else {
      rc = compat_suspend(s);
  }

  return rc >= 0;
}

/*
 * Let guest execution resume after checkpoint_suspend().
 *
 * The domain is resumed through libxc, and QEMU is told to continue for
 * HVM/PVHVM guests. Finally the domain's watchability in xenstore is
 * restored; a failure of that last step is only reported on stderr.
 *
 * Returns 0 on success, -1 on failure.
 */
int checkpoint_resume(checkpoint_state* s)
{
  if (xc_domain_resume(s->xch, s->domid, 1)) {
    snprintf(errbuf, sizeof(errbuf), "error resuming domain: %d", errno);
    s->errstr = errbuf;

    return -1;
  }

  if (s->domtype > dt_pv && resume_qemu(s) < 0)
    return -1;

  /* Restore watchability in xenstore. xs_resume_domain() returns a bool
   * (false on failure), so the former "< 0" test could never fire. */
  if (!xs_resume_domain(s->xsh, s->domid))
    fprintf(stderr, "error resuming domain in xenstore\n");

  return 0;
}

/* Bind the guest's advertised suspend event channel so the guest can be
 * told to suspend itself. On success the local port is stored in
 * s->suspend_evtchn and 0 is returned. On failure s->errstr is set and -1
 * is returned. */
static int setup_suspend_evtchn(checkpoint_state* s)
{
  int guest_port = xs_suspend_evtchn_port(s->domid);

  if (guest_port < 0) {
    s->errstr = "failed to read suspend event channel";
    return -1;
  }

  s->suspend_evtchn = xc_suspend_evtchn_init(s->xch, s->xce, s->domid,
                                             guest_port);
  if (s->suspend_evtchn >= 0) {
    fprintf(stderr, "bound to suspend event channel %u:%d as %d\n", s->domid,
            guest_port, s->suspend_evtchn);
    return 0;
  }

  s->errstr = "failed to bind suspend event channel";
  return -1;
}

/* Release the suspend event channel bound to the guest, if any, and mark it
 * unbound. Does nothing without an event channel handle. */
static void release_suspend_evtchn(checkpoint_state *s)
{
  /* TODO: teach xen to clean up if port is unbound */
  if (s->xce == NULL || s->suspend_evtchn < 0)
    return;

  xc_suspend_evtchn_release(s->xch, s->xce, s->domid, s->suspend_evtchn);
  s->suspend_evtchn = -1;
}

/* Register a watch on @releaseDomain, using the domain id as the token.
 * The registration fire is then handed to check_shutdown(), whose result is
 * deliberately ignored here. Returns 0 on success, or -1 (with s->errstr
 * set) if the watch cannot be registered. */
static int setup_shutdown_watch(checkpoint_state* s)
{
  char token[16];

  /* write domain ID to watch so we can ignore other domain shutdowns */
  snprintf(token, sizeof(token), "%u", s->domid);
  if (!xs_watch(s->xsh, "@releaseDomain", token)) {
    fprintf(stderr, "Could not bind to shutdown watch\n");
    /* previously left errstr unset, so callers saw NULL or a stale message */
    s->errstr = "could not bind to shutdown watch";
    return -1;
  }

  /* watch fires once on registration */
  s->watching_shutdown = 1;
  check_shutdown(s);

  return 0;
}

/*
 * Report whether the target domain has suspended.
 *
 * HVM/PVHVM guests (domtype > dt_pv): first block in pollfd(), which times
 * out, until the @releaseDomain watch fires with this domain's token.
 * Events carrying other tokens are skipped. The first event after
 * registration (watching_shutdown == 1) is the registration fire: it is
 * consumed and 0 is returned.
 *
 * PV guests skip the wait and query the hypervisor immediately.
 *
 * Returns:
 *   -1  on error, or if the domain shut down for a reason other than
 *       suspend (errstr is set only for the error cases);
 *    0  if the domain is running, or for the registration fire;
 *    1  if the domain is suspended.
 */
static int check_shutdown(checkpoint_state* s) {
    unsigned int count;
    char **vec;
    char token[16];
    xc_dominfo_t info;

    if (s->domtype > dt_pv) {
        int xsfd = xs_fileno(s->xsh);

        /* same format as the token registered in setup_shutdown_watch() */
        snprintf(token, sizeof(token), "%u", s->domid);

        /* loop on watch if it fires for another domain */
        for (;;) {
            int ours;

            if (pollfd(s, xsfd) < 0)
                return -1;

            /* vec is a single allocation that must be freed on every path;
             * the original leaked it */
            vec = xs_read_watch(s->xsh, &count);
            if (s->watching_shutdown == 1) {
                free(vec);
                s->watching_shutdown = 2;
                return 0;
            }
            if (!vec) {
                fprintf(stderr, "empty watch fired\n");
                continue;
            }
            ours = !strcmp(vec[XS_WATCH_TOKEN], token);
            free(vec);
            if (ours)
                break;
        }
    }

    if (xc_domain_getinfo(s->xch, s->domid, 1, &info) != 1
        || info.domid != s->domid) {
        snprintf(errbuf, sizeof(errbuf),
                 "error getting info for domain %u", s->domid);
        s->errstr = errbuf;
        return -1;
    }
    if (!info.shutdown) {
        snprintf(errbuf, sizeof(errbuf),
                 "domain %u not shut down", s->domid);
        s->errstr = errbuf;
        return 0;
    }

    if (info.shutdown_reason != SHUTDOWN_suspend)
        return -1;

    return 1;
}

/* Drop the @releaseDomain watch registered by setup_shutdown_watch().
 * Does nothing without a xenstore handle or a registered watch. A failed
 * unwatch is only reported on stderr; the watch is marked released either
 * way. */
static void release_shutdown_watch(checkpoint_state* s) {
  char token[16];

  if (s->xsh == NULL || s->watching_shutdown == 0)
    return;

  snprintf(token, sizeof(token), "%u", s->domid);
  if (!xs_unwatch(s->xsh, "@releaseDomain", token))
    fprintf(stderr, "Could not release shutdown watch\n");

  s->watching_shutdown = 0;
}

/*
 * Suspend the guest through its suspend event channel:
 * - notify the channel;
 * - wait until the same port becomes pending again (the guest's reply);
 * - unmask it for the next round;
 * - confirm with check_shutdown() that the domain is suspended.
 *
 * Returns 0 on success, -1 on failure. s->errstr is set for the event
 * channel failures; check_shutdown() decides what is recorded for a
 * non-suspended domain.
 */
static int evtchn_suspend(checkpoint_state* s)
{
    int rc;
    int port;

    rc = xc_evtchn_notify(s->xce, s->suspend_evtchn);
    if (rc < 0) {
        snprintf(errbuf, sizeof(errbuf),
                 "failed to notify suspend event channel: %d", rc);
        s->errstr = errbuf;

        return -1;
    }

    /* Wait for our port to become pending; events pending on any other
     * port are ignored. */
    do {
        if (pollfd(s, xc_evtchn_fd(s->xce)) < 0)
            return -1;              /* pollfd() has set s->errstr */

        port = xc_evtchn_pending(s->xce);
        if (port < 0) {
            /* this failure used to return without an error message */
            snprintf(errbuf, sizeof(errbuf),
                     "failed to read pending event channel: %s",
                     strerror(errno));
            s->errstr = errbuf;

            return -1;
        }
    } while (port != s->suspend_evtchn);

    if (xc_evtchn_unmask(s->xce, s->suspend_evtchn) < 0) {
        /* the old message printed the port number here, not an error */
        snprintf(errbuf, sizeof(errbuf),
                 "failed to unmask suspend notification channel: %s",
                 strerror(errno));
        s->errstr = errbuf;

        return -1;
    }

    if (check_shutdown(s) != 1)
        return -1;

    return 0;
}

/*
 * Suspend the guest through xenstore. This is used when no suspend event
 * channel is bound: "suspend" is written to
 * /local/domain/<domid>/control/shutdown and the result is confirmed with
 * check_shutdown().
 *
 * NOTE(review): check_shutdown() only waits for the @releaseDomain watch
 * when domtype > dt_pv. For a PV guest it inspects the domain state
 * immediately, so this path may report failure if the guest has not
 * suspended yet -- confirm.
 *
 * Returns 0 on success, -1 on failure.
 */
static int compat_suspend(checkpoint_state* s)
{
    char path[128];

    snprintf(path, sizeof(path), "/local/domain/%u/control/shutdown", s->domid);

    if (!xs_write(s->xsh, XBT_NULL, path, "suspend", 7)) {
        /* the previous message wrongly blamed qemu logdirty */
        s->errstr = "error signalling suspend request via xenstore";
        return -1;
    }

    if (check_shutdown(s) != 1)
        return -1;

    return 0;
}

/* Wait up to 5.5 seconds for fd to become readable.
 * Returns 0 once it is readable. Otherwise it returns -1, with s->errstr
 * pointing at errbuf, which describes a select() error, a timeout, or an
 * unexpected fd set. */
static int pollfd(checkpoint_state* s, int fd)
{
    fd_set readable;
    struct timeval timeout = { .tv_sec = 5, .tv_usec = 500000 };
    int nready;

    FD_ZERO(&readable);
    FD_SET(fd, &readable);

    nready = select(fd + 1, &readable, NULL, NULL, &timeout);
    if (nready > 0 && FD_ISSET(fd, &readable))
        return 0;

    if (nready < 0)
        snprintf(errbuf, sizeof(errbuf),
                 "error polling fd: %s", strerror(errno));
    else if (nready == 0)
        snprintf(errbuf, sizeof(errbuf), "timeout polling fd");
    else
        snprintf(errbuf, sizeof(errbuf), "unknown error polling fd");
    s->errstr = errbuf;

    return -1;
}

/* adapted from the eponymous function in xc_save */
/*
 * Ask the QEMU device model to turn logdirty mode on (enable != 0) or off.
 *
 * Protocol, under /local/domain/0/device-model/<domid>/logdirty/:
 * - watch .../ret;
 * - write "enable" or "disable" to .../cmd;
 * - wait for a watch event;
 * - expect .../ret to echo the command back.
 *
 * Returns 0 on success, or 1 on failure with s->errstr set (always to a
 * string literal).
 */
static int switch_qemu_logdirty(checkpoint_state *s, int enable)
{
    char path[128];
    char *tail, *response;
    const char *cmd = enable ? "enable" : "disable";
    char **vec;
    unsigned int len = 0;
    size_t room;

    snprintf(path, sizeof(path), "/local/domain/0/device-model/%u/logdirty/",
             s->domid);
    tail = path + strlen(path);
    room = sizeof(path) - (size_t)(tail - path);

    snprintf(tail, room, "ret");
    if (!xs_watch(s->xsh, path, "qemu-logdirty-ret")) {
        s->errstr = "error watching qemu logdirty return";
        return 1;
    }
    /* null fire. XXX unify with shutdown watch! */
    vec = xs_read_watch(s->xsh, &len);
    free(vec);

    snprintf(tail, room, "cmd");
    if (!xs_write(s->xsh, XBT_NULL, path, cmd, strlen(cmd))) {
        /* don't leave the ret watch registered on this path */
        snprintf(tail, room, "ret");
        xs_unwatch(s->xsh, path, "qemu-logdirty-ret");
        s->errstr = "error signalling qemu logdirty";
        return 1;
    }

    /* Block until a watch event arrives. The expected event is QEMU
     * answering on .../ret; nothing here verifies which watch fired. */
    vec = xs_read_watch(s->xsh, &len);
    free(vec);

    snprintf(tail, room, "ret");
    xs_unwatch(s->xsh, path, "qemu-logdirty-ret");

    /* xs_read() returns NULL on failure. The old code then called strcmp()
     * on NULL, and it also leaked a zero-length response. */
    len = 0;
    response = xs_read(s->xsh, XBT_NULL, path, &len);
    if (!response || !len || strcmp(response, cmd)) {
        free(response);
        s->errstr = "qemu logdirty command failed";
        return 1;
    }
    free(response);
    fprintf(stderr, "qemu logdirty mode: %s\n", cmd);

    return 0;
}

/* Suspend a plain HVM guest: issue the SHUTDOWN_suspend hypercall, wait for
 * check_shutdown() to confirm the suspend, then pause QEMU.
 * Returns 0 on success, -1 on failure. */
static int suspend_hvm(checkpoint_state *s)
{
    int rc;

    fprintf(stderr, "issuing HVM suspend hypercall\n");
    rc = xc_domain_shutdown(s->xch, s->domid, SHUTDOWN_suspend);
    if (rc < 0) {
        s->errstr = "shutdown hypercall failed";
        return -1;
    }
    fprintf(stderr, "suspend hypercall returned %d\n", rc);

    /* QEMU is only paused once the guest is confirmed suspended */
    return check_shutdown(s) == 1 ? suspend_qemu(s) : -1;
}

/*
 * Pause the QEMU device model. "save" is written to
 * /local/domain/0/device-model/<domid>/command, then .../state is polled
 * every millisecond (with no overall timeout) until it reads "paused".
 *
 * Returns 0 once paused, or -1 (with s->errstr set) if the command cannot be
 * written or the state cannot be read.
 */
static int suspend_qemu(checkpoint_state *s)
{
    char path[128];

    fprintf(stderr, "pausing QEMU\n");

    /* domid is unsigned: %u, not %d */
    snprintf(path, sizeof(path), "/local/domain/0/device-model/%u/command",
             s->domid);
    if (!xs_write(s->xsh, XBT_NULL, path, "save", 4)) {
        fprintf(stderr, "error signalling QEMU to save\n");
        s->errstr = "error signalling QEMU to save";
        return -1;
    }

    snprintf(path, sizeof(path), "/local/domain/0/device-model/%u/state",
             s->domid);

    for (;;) {
        unsigned int len;
        char *state = xs_read(s->xsh, XBT_NULL, path, &len);
        int paused;

        if (!state) {
            s->errstr = "error reading QEMU state";
            return -1;
        }

        paused = !strcmp(state, "paused");
        free(state);
        if (paused)
            return 0;

        usleep(1000);
    }
}

/* Tell the QEMU device model to continue by writing "continue" to
 * /local/domain/0/device-model/<domid>/command.
 * Returns 0 on success, or -1 (with s->errstr set) if the write fails. */
static int resume_qemu(checkpoint_state *s)
{
    char path[128];

    fprintf(stderr, "resuming QEMU\n");

    /* domid is unsigned: %u, not %d */
    snprintf(path, sizeof(path), "/local/domain/0/device-model/%u/command",
             s->domid);
    if (!xs_write(s->xsh, XBT_NULL, path, "continue", 8)) {
        fprintf(stderr, "error signalling QEMU to resume\n");
        s->errstr = "error signalling QEMU to resume";
        return -1;
    }

    return 0;
}

/* Microseconds elapsed from *old to *new.
 *
 * The arithmetic is done in int64_t so the seconds-to-microseconds product
 * cannot overflow a 32-bit time_t (signed overflow is undefined behaviour).
 * If *new is earlier than *old, the negative difference wraps modulo 2^64,
 * as before. */
static uint64_t tv_delta(struct timeval *new, struct timeval *old)
{
    int64_t secs  = (int64_t)new->tv_sec - (int64_t)old->tv_sec;
    int64_t usecs = (int64_t)new->tv_usec - (int64_t)old->tv_usec;

    return (uint64_t)(secs * 1000000 + usecs);
}

/* Set asynchronously by stopme() to make main() leave its loop. Only a
 * volatile sig_atomic_t may portably be written from a signal handler;
 * a plain int is not guaranteed to be safe. */
volatile sig_atomic_t quit = 0;

/* Signal handler for SIGINT, SIGTERM and SIGALRM: request shutdown. */
void stopme(int signum)
{
  (void)signum;
  quit = 1;
}

int main(int argc, char *argv[])
{
  int domid, interval, runtime = 0;
  checkpoint_state s;
  int iter = 0;
  unsigned int scall, rcall;
  struct timeval suspendCall, suspendTime, resumeTime;
  if (argc <3)
    {
      fprintf(stderr, "usage: suspend_resume <domID (not name!)> interval(ms) [testTime (s)]\n");
      exit(1);
    }
  signal(SIGINT, stopme);
  signal(SIGALRM, stopme);
  signal(SIGTERM, stopme);
  domid = atoi(argv[1]);
  interval = atoi(argv[2]);
  if (argc >3)
    runtime = atoi(argv[3]);

  checkpoint_init(&s);
  if (checkpoint_open(&s, domid) < 0)
    {
      fprintf(stderr, "error setting up suspend interface to dom %d\n",domid);
      exit(1);
    }
  if (runtime)
    alarm(runtime);
  
  while(!quit)
    {
      iter++;
      gettimeofday(&suspendCall,0);
      if (!checkpoint_suspend(&s))
	{
	  fprintf(stderr, "failed to suspend domain %d\n", domid);
	  exit(1);
	}
      gettimeofday(&suspendTime, 0);
      if (checkpoint_resume(&s) < 0)
	{
	  fprintf(stderr, "failed to resume domain %d\n", domid);
	  exit(1);
	}
      gettimeofday(&resumeTime, 0);
      scall = (unsigned int)(tv_delta(&suspendTime,&suspendCall));
      rcall = (unsigned int)(tv_delta(&resumeTime,&suspendTime));

      printf("REMUS:%d:suspendAt:%lu.%06lu:scall:%u:rcall:%u:dcall:%u:suspendFor:%u:ctime:%u:flush:%u:commit:%u:tosend:%u:comp:%u\n",
	     iter, suspendCall.tv_sec, suspendCall.tv_usec, scall, rcall, 0, 0, 0, 0, 0, 0, 0, 0, 0);
      usleep(interval * 1000);
    }
  checkpoint_close(&s);
}
