(* 
    OS-specific code for OCaml XenStore Daemon.
    Copyright (C) 2008 Patrick Colp University of British Columbia

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*)

(* Paths of the /proc/xen files read or mapped by this module, all built
   from the common "/proc/xen/" directory prefix. *)
let ( xenstored_proc_domid,
      xenstored_proc_dom0_port,
      xenstored_proc_dom0_mfn,
      xenstored_proc_kva,
      xenstored_proc_port ) =
  let in_proc_xen leaf = "/proc/xen/" ^ leaf in
  ( in_proc_xen "xsd_domid",
    in_proc_xen "xsd_dom0_port",
    in_proc_xen "xsd_dom0_mfn",
    in_proc_xen "xsd_kva",
    in_proc_xen "xsd_port" )

(* Restrict the file behind a Unix-domain socket address to mode 0o600
   (owner read/write only). An Internet address has no file, so it is
   reported through [Utils.barf_perror]. *)
let xsd_chmod = function
  | Unix.ADDR_UNIX path -> Unix.chmod path 0o600
  | Unix.ADDR_INET _ -> Utils.barf_perror "addr -- chmod oops"

(* Look up a XenStore daemon directory in environment variable [env_var],
   falling back to [fallback] when the variable is not set. *)
let xsd_getdir env_var fallback =
  match (try Some (Sys.getenv env_var) with Not_found -> None) with
  | Some dir -> dir
  | None -> fallback

(* Create directory [name] with mode 0o755 unless something already exists
   at that path. Only the last path component is created; parent directories
   must already exist. Other [Unix.Unix_error]s from mkdir propagate. *)
let xsd_mkdir name =
  (* Attempt the mkdir and tolerate EEXIST instead of testing
     [Sys.file_exists] first: the separate check left a window in which
     another process could create the path and make mkdir raise. *)
  try Unix.mkdir name 0o755 with Unix.Unix_error (Unix.EEXIST, _, _) -> ()

(* The XenStore daemon run directory: $XENSTORED_RUNDIR if set, otherwise
   "/var/run/xenstored". *)
let xsd_rundir () =
  let env_var = "XENSTORED_RUNDIR"
  and default_dir = "/var/run/xenstored" in
  xsd_getdir env_var default_dir

(* Base path of the XenStore daemon sockets: $XENSTORED_PATH if set,
   otherwise "<run directory>/socket". *)
let xsd_socket_path () =
  let default_path = xsd_rundir () ^ "/socket" in
  xsd_getdir "XENSTORED_PATH" default_path

(* Path of the read-only socket: the base socket path with "_ro" appended. *)
let xsd_socket_ro () =
  Printf.sprintf "%s_ro" (xsd_socket_path ())

(* Path of the read-write socket, which is simply the base socket path. *)
let xsd_socket_rw = xsd_socket_path

(* Remove a stale socket file left at a Unix-domain address so the path can
   be bound again. A missing file is not an error; other unlink failures
   raise [Unix.Unix_error]. An Internet address has no file, so it is
   reported through [Utils.barf_perror]. *)
let xsd_unlink addr =
  match addr with
  | Unix.ADDR_UNIX name ->
      (* Unlink unconditionally and tolerate ENOENT rather than testing
         [Sys.file_exists] first. That avoids a check-then-act race. It also
         removes dangling symlinks, which [Sys.file_exists] reports as absent
         but which would still make a later bind on the path fail. *)
      (try Unix.unlink name with Unix.Unix_error (Unix.ENOENT, _, _) -> ())
  | Unix.ADDR_INET _ -> Utils.barf_perror "addr -- unlink oops"

(* Accepted socket connections: domain id -> connected descriptor
   (filled by [accept]). *)
let conn_fds : (int, Unix.file_descr) Hashtbl.t = Hashtbl.create 8

(* Id handed to the next accepted socket connection; starts at -1 and is
   decremented after each accept, so socket connections get negative ids. *)
let conn_id = ref (-1)

(* Descriptors reported readable / writable by the most recent select in
   [check_connections]. *)
let in_set : Unix.file_descr list ref = ref []
let out_set : Unix.file_descr list ref = ref []

(* Accept one pending connection on listening [socket] and wrap it as a new
   domain object whose id is the current value of [conn_id]. [can_write],
   [in_set] and [out_set] are passed through to [Socket.socket_interface].
   After the domain is built, [conn_id] is decremented and the accepted
   descriptor is recorded in [conn_fds] under the domain's id. *)
let accept socket can_write in_set out_set =
  let client_fd, _peer = Unix.accept socket in
  let iface = new Socket.socket_interface client_fd can_write in_set out_set in
  let conn = new Connection.connection iface in
  let id = !conn_id in
  let domu = new Domain.domain id conn in
  conn_id := id - 1;
  Hashtbl.add conn_fds domu#id client_fd;
  domu

(* Create a Unix-domain stream socket bound to [socket_name] and listening.
   The steps are:
   - make sure the run directory exists;
   - remove any stale file at the path;
   - bind, and restrict the socket file to mode 0o600;
   - listen with a backlog of 1.
   If any step after the socket is created raises, the socket is closed and
   the original exception is re-raised. *)
let create_socket socket_name =
  xsd_mkdir (xsd_rundir ());
  let addr = Unix.ADDR_UNIX socket_name in
  let socket = Unix.socket Unix.PF_UNIX Unix.SOCK_STREAM 0 in
  try
    xsd_unlink addr;
    (* bind creates the socket file with permissions derived from the
       process umask. Tighten the umask just for the bind so the file never
       exists with looser permissions than the 0o600 xsd_chmod applies. *)
    let old_umask = Unix.umask 0o177 in
    (try Unix.bind socket addr
     with e ->
       ignore (Unix.umask old_umask);
       raise e);
    ignore (Unix.umask old_umask);
    xsd_chmod addr;
    Unix.listen socket 1;
    socket
  with e ->
    (* Don't leak the descriptor. A failure while closing must not mask the
       error that got us here. *)
    (try Unix.close socket with Unix.Unix_error _ -> ());
    raise e

(* Drop from [conn_fds] every entry whose key is not the id of one of
   [domains] having a negative id (socket connections get negative ids from
   [accept]). Entries keyed by a non-negative id are therefore always removed.
   NOTE(review): the removed descriptors are not closed here; confirm that the
   Socket interface owning them closes them. *)
let filter_conn_fds conn_fds domains =
  let active_conn_ids =
    List.fold_left
      (fun ids domain -> if domain#id < 0 then domain#id :: ids else ids)
      [] domains
  in
  (* Collect stale keys first and remove them afterwards: the Hashtbl docs
     leave the behaviour unspecified if the table is modified during
     iteration. *)
  let stale_ids =
    Hashtbl.fold
      (fun id _ acc -> if List.mem id active_conn_ids then acc else id :: acc)
      conn_fds []
  in
  List.iter (Hashtbl.remove conn_fds) stale_ids

(* Fork; the parent exits with status 0 and only the child returns.
   [Unix.fork] reports failure by raising [Unix.Unix_error], never by a
   negative pid. The failure is reported through [Utils.barf_perror] and the
   exception is then re-raised in case barf_perror returns. *)
let fork_daemon () =
  let pid =
    try Unix.fork ()
    with Unix.Unix_error (err, _, _) as exn ->
      ignore (Utils.barf_perror ("Failed to fork daemon: " ^ Unix.error_message err));
      raise exn
  in
  if pid <> 0 then exit 0

(* All descriptors stored in [conn_fds], in the table's fold order. *)
let get_input_socket_connections conn_fds =
  let collect _id fd acc = fd :: acc in
  Hashtbl.fold collect conn_fds []

(* Descriptors of those [domains] that have an entry in [conn_fds] and
   report [can_write]. The result is in reverse order of [domains]. *)
let get_output_socket_connections domains conn_fds =
  let connected = List.filter (fun d -> Hashtbl.mem conn_fds d#id) domains in
  let add_if_writable acc d =
    if d#can_write then Hashtbl.find conn_fds d#id :: acc else acc
  in
  List.fold_left add_if_writable [] connected

(* Read an integer from file [name] using a single read of at most 20 bytes.
   Returns [Constants.null_file_descr] when nothing could be read (empty
   file). Raises [Unix.Unix_error] if the file cannot be opened or read; the
   descriptor is closed before a read error propagates. Raises [Failure] if
   the bytes read are not accepted by [int_of_string]. *)
let read_int_from_proc name =
  let fd = Unix.openfile name [ Unix.O_RDONLY ] 0o600 in
  let buff = Bytes.create 20 in
  let len =
    try Unix.read fd buff 0 (Bytes.length buff)
    with e ->
      Unix.close fd;
      raise e
  in
  Unix.close fd;
  (* [Unix.read] never returns a negative count (errors raise), so the only
     "no value" case is an empty read. *)
  if len > 0 then int_of_string (Bytes.sub_string buff 0 len)
  else Constants.null_file_descr

(* The two listening sockets are created when this module is initialised:
   first the read-write one, then the read-only one. *)
let socket_rw =
  let path = xsd_socket_rw () in
  create_socket path

let socket_ro =
  let path = xsd_socket_ro () in
  create_socket path

(* Listening descriptors whose current contents [check_connections] always
   includes in its select read set. *)
let special_fds = ref (socket_rw :: socket_ro :: [])

(* Run one select round:
   - prune [conn_fds] of connections whose domain is gone;
   - wait, with [xenstored#domains#timeout] as the select timeout, for the
     listening sockets, the event channel (unless it is
     [Constants.null_file_descr]) and the connected sockets to become
     readable, and for writable-domain sockets to become writable;
   - store the ready sets in [in_set] / [out_set];
   - accept a new connection on each listening socket that was readable. *)
let check_connections xenstored event_chan =
  filter_conn_fds conn_fds xenstored#domains#domains;

  let input_conns = get_input_socket_connections conn_fds
  and output_conns = get_output_socket_connections xenstored#domains#domains conn_fds
  and timeout = xenstored#domains#timeout in

  let listen_fds =
    if event_chan = Constants.null_file_descr then !special_fds
    else Socket.file_descr_of_int event_chan :: !special_fds
  in
  let readable, writable, _ =
    Unix.select (listen_fds @ input_conns) output_conns [] timeout
  in
  in_set := readable;
  out_set := writable;

  (* The read-write socket is accepted with can_write = true, the read-only
     socket with can_write = false. *)
  if List.mem socket_rw !in_set then
    xenstored#add_domain (accept socket_rw true in_set out_set);
  if List.mem socket_ro !in_set then
    xenstored#add_domain (accept socket_ro false in_set out_set)

(* True when the descriptor for [event_chan] was reported readable by the
   most recent select stored in [in_set]. *)
let check_event_chan event_chan =
  let chan_fd = Socket.file_descr_of_int event_chan in
  List.exists (fun fd -> fd = chan_fd) !in_set

(* Detach into the background. The steps are:
   - fork twice via [fork_daemon] (each parent exits), with [setsid]
     in between;
   - chdir to "/" and clear the umask;
   - point stdin, stdout and stderr at /dev/null. *)
let daemonise () =
  (* Separate from parent via fork, so init inherits us *)
  fork_daemon ();

  (* Session leader so ^C doesn't whack us *)
  ignore (Unix.setsid ());

  (* Let session leader exit so child cannot regain CTTY *)
  fork_daemon ();

  (* Move off any mount points we might be in *)
  (try Unix.chdir "/" with Unix.Unix_error _ -> Utils.barf_perror "Failed to chdir");

  (* Discard parent's old-fashioned umask prejudices *)
  ignore (Unix.umask 0);

  (* Redirect the standard streams to the null device *)
  let std_fds = [ Unix.stdin; Unix.stdout; Unix.stderr ] in
  let dev_null = Unix.openfile "/dev/null" [ Unix.O_RDWR ] 0o600 in
  List.iter (fun std_fd -> Unix.dup2 dev_null std_fd) std_fds;
  (* If a standard descriptor was closed at startup, openfile may have
     returned that very number; closing it would undo the redirection. *)
  if not (List.mem dev_null std_fds) then Unix.close dev_null

(* The XenStore domain id, read from /proc/xen/xsd_domid. *)
let get_domxs_id = fun () ->
  read_int_from_proc xenstored_proc_domid

(* The Domain-0 mfn, read from /proc/xen/xsd_dom0_mfn. *)
let get_dom0_mfn = fun () ->
  read_int_from_proc xenstored_proc_dom0_mfn

(* The Domain-0 port, read from /proc/xen/xsd_dom0_port. *)
let get_dom0_port = fun () ->
  read_int_from_proc xenstored_proc_dom0_port

(* The process id of the running daemon. *)
let get_pid = Unix.getpid

(* The current local time formatted as "YYYYMMDD HH:MM:SS". *)
let get_time () =
  let { Unix.tm_year; tm_mon; tm_mday; tm_hour; tm_min; tm_sec; _ } =
    Unix.localtime (Unix.gettimeofday ())
  in
  Printf.sprintf "%04d%02d%02d %02d:%02d:%02d"
    (tm_year + 1900) (tm_mon + 1) tm_mday tm_hour tm_min tm_sec;;

(* The XenBus port, read from /proc/xen/xsd_port.
   This delegates to [read_int_from_proc] like the other /proc/xen getters
   instead of duplicating its open/read/parse logic. The old copy compared
   the byte count against -1 (which [Unix.read] never returns) and leaked
   the descriptor if the read raised. *)
let get_xenbus_port () =
  read_int_from_proc xenstored_proc_port

(* OS-specific start-up: ignore SIGPIPE, so that writing to a peer that has
   gone away raises an error instead of killing the process. *)
let init () =
  Sys.set_signal Sys.sigpipe Sys.Signal_ignore

(* Map the XenBus page exposed at /proc/xen/xsd_kva and wrap it in a
   [Xenbus.xenbus_interface] for [port]. The descriptor is only needed for
   the [Xenbus.mmap] call. It is closed afterwards, and also when the mapping
   or the interface construction raises (the exception is re-raised). *)
let map_xenbus port =
  let fd = Unix.openfile xenstored_proc_kva [ Unix.O_RDWR ] 0o600 in
  let interface =
    try new Xenbus.xenbus_interface port (Xenbus.mmap (Socket.int_of_file_descr fd))
    with e ->
      Unix.close fd;
      raise e
  in
  Unix.close fd;
  interface

(* Hook for OS-specific command-line options; this OS layer has none, so
   the argument is ignored. *)
let parse_option _option = ()

(* Write this process's pid (decimal, no trailing newline) to [pid_file]
   while holding an exclusive [lockf] lock on it. If the lock cannot be
   taken (another daemon holds it), exit silently with status 0. Write
   failures are reported through [Utils.barf_perror].

   The descriptor is deliberately never closed. POSIX record locks are
   released as soon as the process closes any descriptor for the file, so
   closing it would drop the lock straight away and a second daemon would
   no longer be detected. *)
let write_pid_file pid_file =
  let fd = Unix.openfile pid_file [ Unix.O_RDWR; Unix.O_CREAT ] 0o600 in

  (* Exit silently if daemon already running *)
  (try Unix.lockf fd Unix.F_TLOCK 0 with Unix.Unix_error _ -> ignore (exit 0));

  let pid = string_of_int (Unix.getpid ()) in
  let len = String.length pid in

  try
    (* Truncate only once the lock is held (O_TRUNC at open time would
       clobber a running daemon's pid), so no digits from a longer previous
       pid are left behind. *)
    Unix.ftruncate fd 0;
    if Unix.write_substring fd pid 0 len <> len then
      Utils.barf_perror ("Writing pid file " ^ pid_file)
  with Unix.Unix_error _ -> Utils.barf_perror ("Writing pid file " ^ pid_file)
