(* 
    Processing for OCaml XenStore Daemon.
    Copyright (C) 2008 Patrick Colp University of British Columbia

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*)

(* Check for a valid domain ID: the string must be accepted by int_of_string
   (which also accepts forms such as "0x1f") and denote a non-negative value.
   Only Failure, the exception int_of_string raises on bad input, is caught. *)
let check_domain_id domain_id =
  try int_of_string domain_id >= 0 with Failure _ -> false

(* Check that the payload holds a single valid domain ID: the first field
   from Utils.split must be exactly one character shorter than the whole
   payload (presumably the trailing null terminator) and pass check_domain_id.
   Returns false instead of raising when Utils.split yields no fields. *)
let check_domain_id_only payload =
  match Utils.split payload with
  | domain_id :: _ ->
      String.length domain_id = pred (String.length payload)
      && check_domain_id domain_id
  | [] -> false

(* Check that the string parses as a 32-bit integer with Int32.of_string.
   Only Failure, the exception Int32.of_string raises on malformed or
   out-of-range input, is caught. *)
let check_int int =
  try ignore (Int32.of_string int); true with Failure _ -> false

(* Check an introduce payload: exactly three fields (domain ID, mfn, port),
   optionally followed by a fourth reserved field. The domain ID must pass
   check_domain_id; mfn and port must parse as 32-bit integers. *)
let check_introduce payload =
  match Utils.split payload with
  | [ domain_id; mfn; port ] | [ domain_id; mfn; port; _ ] ->
      check_domain_id domain_id && check_int mfn && check_int port
  | _ -> false

(* Return true when every character of [path] from index [i] to the end
   occurs in Store.valid_characters (vacuously true once [i] is past the end) *)
let rec check_chars path i =
  i >= String.length path
  || (String.contains Store.valid_characters path.[i] && check_chars path (succ i))

(* Check for a valid path.
   - The empty string is rejected.
   - A path ending in Store.dividor is valid only if it is exactly
     Store.root_path.
   - A path for which Utils.strstr path "//" holds (an empty component) is
     rejected.
   - Relative paths (Store.is_relative) must be at most
     Constants.relative_path_max long; other paths must start with
     Store.root_path and be at most Constants.absolute_path_max long.
   - All characters must pass check_chars.
   The root-prefix test is guarded by a length check so String.sub cannot
   raise Invalid_argument when the path is shorter than Store.root_path. *)
let check_path path =
  let length = String.length path in
  let root_length = String.length Store.root_path in
  if length = 0 then false
  else if path.[pred length] = Store.dividor then
    (* Only the root itself may end with the separator *)
    path = Store.root_path
  else if Utils.strstr path "//" then false
  else if Store.is_relative path then
    length <= Constants.relative_path_max && check_chars path 0
  else
    length >= root_length
    && String.sub path 0 root_length = Store.root_path
    && length <= Constants.absolute_path_max
    && check_chars path 0

(* Check a payload that carries a single path: the payload must be exactly
   one character longer than what Utils.strip_null leaves, and the remaining
   path must pass check_path *)
let check_path_only payload =
  let path = Utils.strip_null payload in
  let single_terminator = String.length payload = String.length path + 1 in
  single_terminator && check_path path

(* Check a set permissions payload: a valid path followed by permission
   strings accepted by Permission.permission_of_string. When the payload is
   not terminated by Constants.null_char, the last field is a reserved value
   that is not parsed as a permission, so three fields are required instead
   of two.
   Returns false (rather than raising Invalid_argument or Failure) for an
   empty payload or one with too few fields. *)
let check_permissions payload =
  let payload_length = String.length payload in
  if payload_length = 0 then false
  else
    let null_terminated = payload.[pred payload_length] = Constants.null_char in
    let split = Utils.split payload in
    let min_length = if null_terminated then 2 else 3 in
    if List.length split < min_length then false
    else
      let perm_list =
        if null_terminated then List.tl split else Utils.remove_last (List.tl split)
      in
      (* The exception permission_of_string raises is not visible from here,
         so any exception marks the permission as invalid *)
      let valid_permission perm =
        try ignore (Permission.permission_of_string perm); true with _ -> false
      in
      check_path (List.hd split) && List.for_all valid_permission perm_list

(* Check a transaction end payload: a single terminated value that is either
   Constants.payload_true or Constants.payload_false *)
let check_transaction_end payload =
  let value = Utils.strip_null payload in
  if String.length payload <> String.length value + 1 then false
  else List.mem value [ Constants.payload_true; Constants.payload_false ]

(* A transaction start payload must consist of exactly one
   Constants.null_char *)
let check_transaction_start payload =
  payload = String.make 1 Constants.null_char

(* Check a watch path: names recognised by Store.is_event only need valid
   characters; any other path must pass the full check_path validation *)
let check_watch_path path =
  match Store.is_event path with
  | true -> check_chars path 0
  | false -> check_path path

(* TODO: Check for a valid watch token. Every token is accepted for now. *)
let check_watch_token _token = true

(* Check a watch/unwatch payload: a watch path and a token, optionally
   followed by a third reserved field *)
let check_watch payload =
  match Utils.split payload with
  | [ path; token ] | [ path; token; _ ] ->
      check_watch_path path && check_watch_token token
  | _ -> false

(* Check a write payload: a valid path, optionally followed by one value
   field (as split by Utils.split) *)
let check_write payload =
  match Utils.split payload with
  | [ path ] | [ path; _ ] -> check_path path
  | _ -> false

(* Check a message to make sure its payload is well-formed for its message
   type. Message types without a validator here are rejected.
   NOTE(review): XS_GET_DOMAIN_PATH, XS_IS_DOMAIN_INTRODUCED and XS_RELEASE
   carry a domain ID but are validated as paths; check_domain_id_only looks
   intended for them. Confirm before switching, because it also rejects
   negative IDs.
   NOTE(review): XS_RESUME passes validation here, but process has no handler
   for it, so it is answered with EINVAL anyway. *)
let check message =
  let payload = message.Message.payload in
  match message.Message.header.Message.message_type with
  | Message.XS_DIRECTORY
  | Message.XS_GET_DOMAIN_PATH
  | Message.XS_GET_PERMS
  | Message.XS_IS_DOMAIN_INTRODUCED
  | Message.XS_MKDIR
  | Message.XS_READ
  | Message.XS_RELEASE
  | Message.XS_RESUME
  | Message.XS_RM -> check_path_only payload
  | Message.XS_INTRODUCE -> check_introduce payload
  | Message.XS_SET_PERMS -> check_permissions payload
  | Message.XS_TRANSACTION_END -> check_transaction_end payload
  | Message.XS_TRANSACTION_START -> check_transaction_start payload
  | Message.XS_UNWATCH | Message.XS_WATCH -> check_watch payload
  | Message.XS_WRITE -> check_write payload
  | _ -> false

(* Return the paths that creating [path] would bring into existence: [path]
   itself and then each ancestor (via Store.parent_path), nearest first,
   stopping before the first one for which store#node_exists holds.
   Presumably some ancestor (e.g. the root) always exists, otherwise this
   would not terminate. *)
let created_paths store path =
  let rec walk missing current =
    if store#node_exists current
    then List.rev missing
    else walk (current :: missing) (Store.parent_path current)
  in
  walk [] path

(* Return the paths found by descending from [path] through Store.Children
   and Store.Hack nodes, in depth-first order. A node of any other kind
   contributes its own path; nodes that have a children list are not
   themselves included.
   Built with an accumulator and a single List.rev: the previous
   `paths @ ...` inside a fold was quadratic in the number of siblings.
   NOTE(review): created_paths counts every newly created node, including
   intermediate ones, while this list omits interior nodes. As a result,
   domain_entry_decr in process_rm may run fewer times than domain_entry_incr
   did. Confirm whether that asymmetry is intended. *)
let removed_paths store path =
  let rec collect acc current =
    match store#read_node current with
    | Store.Children children | Store.Hack (_, children) ->
        List.fold_left (fun acc child -> collect acc child#path) acc children
    | _ -> current :: acc
  in
  List.rev (collect [] path)

(* Process a directory message.
   If [domain] has READ permission on the path, reply with
   Constants.null_string followed by the Store.base_path of each child whose
   path passes check_path, each null-terminated. Nodes without a children
   list yield just Constants.null_string. Permission denial is answered with
   EACCES, and a Constants.Xs_error becomes an error reply carrying its errno.
   The payload is accumulated in a Buffer because repeated (^) in a fold is
   quadratic in the number of children. *)
let process_directory domain store xenstored message =
  let path = Store.canonicalise domain (Utils.strip_null message.Message.payload) in
  try
    if xenstored#permissions#check store path Permission.READ domain#id
    then begin
      let buffer = Buffer.create 256 in
      Buffer.add_string buffer Constants.null_string;
      (match store#read_node path with
       | Store.Children children | Store.Hack (_, children) ->
           List.iter
             (fun child ->
               if check_path child#path
               then Buffer.add_string buffer (Utils.null_terminate (Store.base_path child#path)))
             children
       | _ -> ());
      domain#add_output_message (Message.reply message (Buffer.contents buffer))
    end
    else domain#add_output_message (Message.error message Constants.EACCES)
  with Constants.Xs_error (errno, _, _) -> domain#add_output_message (Message.error message errno)

(* Process a get domain path message: reply with Store.domain_root followed
   by the requested domain ID, null-terminated. The ID itself is not
   re-validated here. *)
let process_get_domain_path domain store message =
  let requested = Utils.strip_null message.Message.payload in
  let reply = Message.reply message (Utils.null_terminate (Store.domain_root ^ requested)) in
  domain#add_output_message reply

(* Process a get permissions message.
   If [domain] has READ permission on the path, reply with
   Constants.null_string followed by each permission of the path rendered by
   Permission.string_of_permission and null-terminated. Otherwise reply
   EACCES. A Constants.Xs_error raised while checking or fetching permissions
   (e.g. for a node that does not exist) now becomes an error reply with its
   errno instead of escaping, matching the read and directory handlers. *)
let process_get_perms domain store xenstored message =
  let path = Store.canonicalise domain (Utils.strip_null message.Message.payload) in
  try
    if xenstored#permissions#check store path Permission.READ domain#id
    then begin
      let permissions = xenstored#permissions#get store path in
      let encoded =
        List.map
          (fun permission -> Utils.null_terminate (Permission.string_of_permission permission))
          permissions
      in
      let payload = Constants.null_string ^ String.concat "" encoded in
      domain#add_output_message (Message.reply message payload)
    end
    else domain#add_output_message (Message.error message Constants.EACCES)
  with Constants.Xs_error (errno, _, _) -> domain#add_output_message (Message.error message errno)

(* Process an introduce message.
   Fields are domain ID, mfn, port and an optional reserved value. An
   unprivileged domain (Domain.is_unprivileged) gets EACCES. Otherwise a new
   domain is built with Domain.domu_init and registered, the
   "@introduceDomain" watches are fired, and the request is acked.
   NOTE(review): check_introduce validates mfn/port with Int32.of_string, but
   they are converted here with int_of_string; on platforms where int is
   narrower than 32 bits, this conversion could still raise Failure. *)
let process_introduce domain store xenstored message =
  let fields = Utils.split message.Message.payload in
  let field n = List.nth fields n in
  let domid = field 0
  and mfn = field 1
  and port = field 2
  and reserved = if List.length fields = 4 then field 3 else Constants.null_string in
  if Domain.is_unprivileged domain
  then domain#add_output_message (Message.error message Constants.EACCES)
  else begin
    (* XXX: Reserved value is currently ignored *)
    ignore reserved;
    let in_transaction = message.Message.header.Message.transaction_id <> 0l in
    let domu = Domain.domu_init (int_of_string domid) (int_of_string port) (int_of_string mfn) false in
    xenstored#add_domain domu;
    xenstored#watches#fire_watches "@introduceDomain" in_transaction false;
    domain#add_output_message (Message.ack message)
  end

(* Process an is-domain-introduced message.
   The payload is parsed with int_of_string. check only validates it as a
   path, so an unparsable payload is now answered with EINVAL instead of
   letting Failure escape. The null-terminated reply is
   Constants.payload_true when the ID equals Constants.domain_id_self or
   xenstored#domains#find_by_id finds it, and Constants.payload_false
   otherwise. *)
let process_is_domain_introduced domain store xenstored message =
  let parsed =
    try Some (int_of_string (Utils.strip_null message.Message.payload))
    with Failure _ -> None
  in
  match parsed with
  | None -> domain#add_output_message (Message.error message Constants.EINVAL)
  | Some domid ->
      let domain_exists =
        try ignore (xenstored#domains#find_by_id domid); true
        with Not_found -> false
      in
      let answer =
        if domid = Constants.domain_id_self || domain_exists
        then Constants.payload_true
        else Constants.payload_false
      in
      domain#add_output_message (Message.reply message (Utils.null_terminate answer))

(* Process a mkdir message.
   Existence is detected indirectly. If xenstored#permissions#check succeeds,
   the node is taken to exist, and the reply is an ack when [domain] may
   WRITE it and EACCES otherwise. If the check raises any exception, the
   handler falls through to creating the node when store#node_exists says it
   is missing, and then acks.
   On creation, only [path] receives a permissions entry (permissions#add).
   Every path from created_paths is passed to domain_entry_incr. Outside a
   transaction (id 0l), permissions... rather, xenstored#transactions#invalidate
   is called on the path and watches are fired.
   NOTE(review): the catch-all `with _` also swallows exceptions raised by the
   first branch's domain#add_output_message and then falls through to
   creation. Consider matching only the exception permissions#check raises
   for a missing node.
   NOTE(review): creating a node performs no WRITE check on any existing
   ancestor (process_rm checks the parent of a missing node). Confirm this is
   intended. *)
let process_mkdir domain store xenstored message =
  let path = Store.canonicalise domain (Utils.strip_null message.Message.payload)
  and transaction = Transaction.make domain#id message.Message.header.Message.transaction_id in
  (* If permissions exist, node already exists *)
  try
    if xenstored#permissions#check store path Permission.WRITE domain#id
    then domain#add_output_message (Message.ack message)
    else domain#add_output_message (Message.error message Constants.EACCES)
  with _ ->
      try
        if not (store#node_exists path)
        then (
          let paths = created_paths store path in
          store#create_node path;
          xenstored#permissions#add store path domain#id;
          List.iter (fun path -> xenstored#domain_entry_incr store transaction path) paths;
          if message.Message.header.Message.transaction_id = 0l
          then (
            xenstored#transactions#invalidate path;
            xenstored#watches#fire_watches path false false
          )
        );
        domain#add_output_message (Message.ack message)
      (* Re-raises unchanged: creation errors propagate to the caller *)
      with e -> raise e (*domain#add_output_message (Message.error message Constants.EINVAL)*)

(* Process a read message.
   Without READ permission the reply is EACCES. Otherwise the reply is the
   node's value for Store.Value and Store.Hack nodes, or
   Constants.null_string for any other node kind. A Constants.Xs_error
   becomes an error reply carrying its errno. *)
let process_read domain store xenstored message =
  let path = Store.canonicalise domain (Utils.strip_null message.Message.payload) in
  try
    if not (xenstored#permissions#check store path Permission.READ domain#id)
    then domain#add_output_message (Message.error message Constants.EACCES)
    else begin
      let contents =
        match store#read_node path with
        | Store.Value value | Store.Hack (value, _) -> value
        | _ -> Constants.null_string
      in
      domain#add_output_message (Message.reply message contents)
    end
  with Constants.Xs_error (errno, _, _) -> domain#add_output_message (Message.error message errno)

(* Process a release message.
   Only a domain whose id is <= 0 may release another domain; anyone else
   gets EACCES. The payload is parsed with int_of_string. check only
   validates it as a path, so an unparsable payload is now answered with
   EINVAL instead of letting Failure escape. An unknown domain (Not_found)
   yields ENOENT. "@releaseDomain" watches fire only for positive IDs. *)
let process_release domain store xenstored message =
  if domain#id > 0
  then domain#add_output_message (Message.error message Constants.EACCES)
  else begin
    let parsed =
      try Some (int_of_string (Utils.strip_null message.Message.payload))
      with Failure _ -> None
    in
    match parsed with
    | None -> domain#add_output_message (Message.error message Constants.EINVAL)
    | Some domu_id ->
        (try
           xenstored#remove_domain (xenstored#domains#find_by_id domu_id);
           if domu_id > 0 then xenstored#watches#fire_watches "@releaseDomain" false false;
           domain#add_output_message (Message.ack message)
         with Not_found -> domain#add_output_message (Message.error message Constants.ENOENT))
  end

(* Process a rm message.
   Missing node: ack when its parent exists and is writable, EACCES when the
   parent is not writable, ENOENT when the parent is missing too.
   Existing node: requires WRITE permission (EACCES otherwise) and must not
   be Store.root_path (EINVAL otherwise). Each path from removed_paths is
   passed to xenstored#domain_entry_decr. The node and its permissions are
   then removed. Outside a transaction (id 0l),
   xenstored#transactions#invalidate is called on the path and watches are
   fired. Finally the request is acked.
   A Constants.Xs_error becomes an error reply carrying its errno. *)
let process_rm domain store xenstored message =
  let path = Store.canonicalise domain (Utils.strip_null message.Message.payload)
  and transaction = Transaction.make domain#id message.Message.header.Message.transaction_id in
  let reply_error errno = domain#add_output_message (Message.error message errno) in
  let writable target = xenstored#permissions#check store target Permission.WRITE domain#id in
  try
    if not (store#node_exists path) then begin
      if not (store#node_exists (Store.parent_path path))
      then reply_error Constants.ENOENT (* XXX: This might be wrong *)
      else if writable (Store.parent_path path)
      then domain#add_output_message (Message.ack message)
      else reply_error Constants.EACCES
    end
    else if not (writable path) then reply_error Constants.EACCES
    else if path = Store.root_path then reply_error Constants.EINVAL
    else begin
      let doomed = removed_paths store path in
      List.iter (fun removed -> xenstored#domain_entry_decr store transaction removed) doomed;
      store#remove_node path;
      xenstored#permissions#remove store path;
      if message.Message.header.Message.transaction_id = 0l then begin
        xenstored#transactions#invalidate path;
        xenstored#watches#fire_watches path false true
      end;
      domain#add_output_message (Message.ack message)
    end
  with Constants.Xs_error (errno, _, _) -> reply_error errno

(* Process a set permissions message.
   The payload is a path followed by permission strings. When the payload is
   not null-terminated, its last field is a reserved value, which is
   currently ignored. With WRITE permission, the permissions are stored, the
   path's watches fire, and the request is acked. A failure while storing or
   firing is answered with EACCES. Without WRITE permission the reply is
   EACCES.
   A Constants.Xs_error raised by the permission check itself (e.g. for a
   node that does not exist) now becomes an error reply with its errno
   instead of escaping, matching the read/directory/rm handlers. *)
let process_set_perms domain store xenstored message =
  let payload = message.Message.payload in
  let split = Utils.split payload in
  let path = Store.canonicalise domain (List.hd split) in
  let (permissions, reserved) =
    if payload.[pred (String.length payload)] = Constants.null_char
    then (List.tl split, Constants.null_string)
    else (Utils.remove_last (List.tl split), List.nth split (pred (List.length split))) in
  try
    if xenstored#permissions#check store path Permission.WRITE domain#id
    then begin
      (* XXX: Reserved value is currently ignored *)
      ignore reserved;
      try
        xenstored#permissions#set permissions store path;
        xenstored#watches#fire_watches path (message.Message.header.Message.transaction_id <> 0l) false;
        domain#add_output_message (Message.ack message)
      with _ -> domain#add_output_message (Message.error message Constants.EACCES) (* XXX: errno? *)
    end
    else domain#add_output_message (Message.error message Constants.EACCES)
  with Constants.Xs_error (errno, _, _) -> domain#add_output_message (Message.error message errno)

(* Process a transaction end message.
   An unknown transaction gets ENOENT. Otherwise the trace entry is
   destroyed. If the payload is Constants.payload_true, xenstored#commit is
   attempted and a failed commit is answered with EAGAIN. Every other case
   is acked.
   NOTE(review): for an abort, nothing in this handler discards the
   transaction. Confirm it is cleaned up elsewhere. *)
let process_transaction_end domain store xenstored message =
  let transaction = Transaction.make domain#id message.Message.header.Message.transaction_id in
  let reply response = domain#add_output_message response in
  if not (xenstored#transactions#exists transaction)
  then reply (Message.error message Constants.ENOENT)
  else begin
    Trace.destroy domain#id "transaction";
    let committing = Utils.strip_null message.Message.payload = Constants.payload_true in
    if committing && not (xenstored#commit transaction)
    then reply (Message.error message Constants.EAGAIN)
    else reply (Message.ack message)
  end

(* Process a transaction start message.
   Starting a transaction from inside one (non-zero transaction id) is
   answered with EBUSY. Otherwise a new transaction is created and its id is
   returned as a null-terminated decimal string. A Constants.Xs_error
   becomes an error reply carrying its errno. *)
let process_transaction_start domain store xenstored message =
  try
    if message.Message.header.Message.transaction_id <> 0l
    then domain#add_output_message (Message.error message Constants.EBUSY)
    else begin
      let transaction = xenstored#new_transaction domain store in
      let id_string = Int32.to_string transaction.Transaction.transaction_id in
      domain#add_output_message (Message.reply message (Utils.null_terminate id_string))
    end
  with Constants.Xs_error (errno, _, _) -> domain#add_output_message (Message.error message errno)

(* Process an unwatch message.
   Relative paths are canonicalised for [domain] before building the Watch.
   If xenstored#watches#remove reports success, the trace entry is destroyed
   and the request is acked; otherwise the reply is ENOENT. A third
   (reserved) payload field is ignored. *)
let process_unwatch domain store xenstored message =
  let fields = Utils.split message.Message.payload in
  let field n = List.nth fields n in
  let path = field 0
  and token = field 1
  and reserved = if List.length fields = 3 then field 2 else Constants.null_string in
  let relative = Store.is_relative path in
  let watch_path = if relative then Store.canonicalise domain path else path in
  (* XXX: Reserved value is currently ignored *)
  ignore reserved;
  match xenstored#watches#remove (Watch.make domain watch_path token relative) with
  | true ->
      Trace.destroy domain#id "watch";
      domain#add_output_message (Message.ack message)
  | false -> domain#add_output_message (Message.error message Constants.ENOENT)

(* Process a watch message.
   Relative paths are canonicalised for [domain] before building the Watch.
   If xenstored#add_watch accepts it, a trace entry is created, the request
   is acked, and an initial event carrying the path as sent by the client
   plus the token (each null-terminated) is queued. Otherwise the reply is
   EEXIST. A third (reserved) payload field is ignored. *)
let process_watch domain store xenstored message =
  let fields = Utils.split message.Message.payload in
  let field n = List.nth fields n in
  let path = field 0
  and token = field 1
  and reserved = if List.length fields = 3 then field 2 else Constants.null_string in
  let relative = Store.is_relative path in
  let watch_path = if relative then Store.canonicalise domain path else path in
  (* XXX: Reserved value is currently ignored *)
  ignore reserved;
  let watch = Watch.make domain watch_path token relative in
  if not (xenstored#add_watch domain watch)
  then domain#add_output_message (Message.error message Constants.EEXIST)
  else begin
    Trace.create domain#id "watch";
    domain#add_output_message (Message.ack message);
    let initial_event = Message.event (Utils.null_terminate path ^ Utils.null_terminate token) in
    domain#add_output_message initial_event
  end

(* Process a write message.
   The payload is split into a path (canonicalised for [domain]) and a value
   rebuilt from the remaining fields with Utils.combine. An existing node
   requires WRITE permission (EACCES otherwise). An unprivileged domain whose
   value length reaches quota_max_entry_size gets ENOSPC. A missing node is
   created first: [domain] gets a permissions entry on the path, and each
   path from created_paths is passed to domain_entry_incr. After the value
   is written outside a transaction (id 0l),
   xenstored#transactions#invalidate is called on the path and watches are
   fired. Exceptions raised while creating or writing propagate to the
   caller.
   NOTE(review): creating a node performs no WRITE check on any existing
   ancestor (process_rm checks the parent of a missing node). Confirm this
   is intended.
   NOTE(review): the quota test uses >=, so a value exactly
   quota_max_entry_size bytes long is rejected. Confirm the intended bound. *)
let process_write domain store xenstored message =
  let fields = Utils.split message.Message.payload in
  let path = Store.canonicalise domain (List.hd fields)
  and value = Utils.combine (List.tl fields) in
  let transaction_id = message.Message.header.Message.transaction_id in
  let transaction = Transaction.make domain#id transaction_id in
  let reply_error errno = domain#add_output_message (Message.error message errno) in
  if store#node_exists path
     && not (xenstored#permissions#check store path Permission.WRITE domain#id)
  then reply_error Constants.EACCES
  else if Domain.is_unprivileged domain
          && String.length value >= xenstored#options.Option.quota_max_entry_size
  then reply_error Constants.ENOSPC
  else begin
    if not (store#node_exists path) then begin
      let created = created_paths store path in
      store#create_node path;
      xenstored#permissions#add store path domain#id;
      List.iter (fun created_path -> xenstored#domain_entry_incr store transaction created_path) created
    end;
    store#write_node path value;
    if transaction_id = 0l then begin
      xenstored#transactions#invalidate path;
      xenstored#watches#fire_watches path false false
    end;
    domain#add_output_message (Message.ack message)
  end

(* Process the message returned by domain#input_message. The store is
   selected for the message's transaction, the payload is validated with
   check, and the message is dispatched to the matching process_* handler.
   Invalid payloads and message types without a handler get EINVAL. *)
let process (xenstored : Xenstored.xenstored) domain =
  let message = domain#input_message in
  let header = message.Message.header in
  let store = xenstored#transactions#store (Transaction.make domain#id header.Message.transaction_id) in
  let reject () = domain#add_output_message (Message.error message Constants.EINVAL) in
  if not (check message) then reject ()
  else
    match header.Message.message_type with
    | Message.XS_DIRECTORY -> process_directory domain store xenstored message
    | Message.XS_GET_DOMAIN_PATH -> process_get_domain_path domain store message
    | Message.XS_GET_PERMS -> process_get_perms domain store xenstored message
    | Message.XS_INTRODUCE -> process_introduce domain store xenstored message
    | Message.XS_IS_DOMAIN_INTRODUCED -> process_is_domain_introduced domain store xenstored message
    | Message.XS_MKDIR -> process_mkdir domain store xenstored message
    | Message.XS_READ -> process_read domain store xenstored message
    | Message.XS_RELEASE -> process_release domain store xenstored message
    | Message.XS_RM -> process_rm domain store xenstored message
    | Message.XS_SET_PERMS -> process_set_perms domain store xenstored message
    | Message.XS_TRANSACTION_END -> process_transaction_end domain store xenstored message
    | Message.XS_TRANSACTION_START -> process_transaction_start domain store xenstored message
    | Message.XS_UNWATCH -> process_unwatch domain store xenstored message
    | Message.XS_WATCH -> process_watch domain store xenstored message
    | Message.XS_WRITE -> process_write domain store xenstored message
    | _ -> reject ()
