#!/usr/bin/env python
#
# Copyright (C) 2010 Oracle. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, version 2.  This program is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.  You should have received a copy of the GNU
# General Public License along with this program; if not, write to the Free
# Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 021110-1307,
# USA.
#
# Some code from virtinst:
#
#      LICENSE: GPLv2
#      Copyright 2006-2007 Red Hat, Inc.
#      Author: Daniel P. Berrange <berrange@redhat.com>
#
# Authors:
#
#  - Zhigang Wang <zhigang.x.wang@oracle.com>

import sys
import os
import stat
import time
import string
import random
import tempfile
import commands
import subprocess
import urlgrabber
from optparse import OptionParser


# Known (kernel, ramdisk) path pairs, relative to the install location, for
# the distributions listed below.  main() walks this list in order when
# --kernel / --ramdisk are not given and keeps the first file it can fetch.
xen_paths = [
    # Fedora
    ('images/xen/vmlinuz',
     'images/xen/initrd.img'),
    # Opensuse > 10.2 and sles 10 (i386)
    ('boot/i386/vmlinuz-xen',
     'boot/i386/initrd-xen'),
    # Opensuse > 10.2 and sles 10 (x86_64)
    ('boot/x86_64/vmlinuz-xen',
     'boot/x86_64/initrd-xen'),
    # Debian
    ('current/images/netboot/xen/vmlinuz',
     'current/images/netboot/xen/initrd.gz'),
]


class Fetcher:
    """Fetch boot files from a location that urlgrabber can read directly.

    ``location`` is the base that requested file names are joined onto and
    ``tmpdir`` is the directory that receives the downloaded copies.
    Subclasses override ``prepare``/``cleanup``/``get_file`` for locations
    that must be mounted or fetched differently.
    """

    def __init__(self, location, tmpdir='/var/run/xend/boot'):
        self.location = location
        self.tmpdir = tmpdir
        # Until a subclass mounts something, files are read straight from
        # the location itself.
        self.srcdir = location

    def prepare(self):
        """Make sure ``tmpdir`` exists, creating it with mode 0750.

        Raises OSError when the directory cannot be created.
        """
        try:
            # 0o750 is the spelling accepted by both Python 2.6+ and 3.
            os.makedirs(self.tmpdir, 0o750)
        except OSError:
            # An already existing directory is fine (possibly created by a
            # concurrent run); every other failure is a real error.
            if not os.path.isdir(self.tmpdir):
                raise

    def cleanup(self):
        """Release resources taken by prepare(); nothing to do here."""
        pass

    def get_file(self, filename):
        """Copy ``filename`` (relative to the source) into ``tmpdir``.

        Returns the path of the local copy.  Raises RuntimeError when the
        file cannot be fetched; no temporary file is left behind then.
        """
        url = os.path.join(self.srcdir, filename)
        local_name = None
        try:
            # mkstemp picks an unused name and creates the file atomically,
            # so an existing file is never silently overwritten.
            (fd, local_name) = tempfile.mkstemp(
                prefix='xenpvboot.%s.' % os.path.basename(filename),
                dir=self.tmpdir)
            os.close(fd)
            return urlgrabber.urlgrab(url, local_name, copy_local=1)
        except Exception as e:
            if local_name is not None:
                try:
                    os.unlink(local_name)
                except OSError:
                    pass
            raise RuntimeError('Cannot get file %s: %s' % (url, e))


class MountedFetcher(Fetcher):
    """Fetch boot files from a source that has to be mounted first.

    Handles ``nfs:host:/path`` exports as well as local block devices and
    image files; everything is mounted read-only on a temporary directory
    below ``tmpdir``.
    """

    def _mount(self, dev, path, option=''):
        """Mount ``dev`` on ``path``.

        ``option`` is a whitespace-separated string of extra arguments for
        the mount command (for example ``'-o ro,loop'``).  Raises
        RuntimeError when the command cannot be started or exits non-zero.
        """
        if os.uname()[0] == 'SunOS':
            mountcmd = '/usr/sbin/mount'
        else:
            mountcmd = '/bin/mount'
        # Run with an argument vector rather than a shell string so that
        # spaces or shell metacharacters in the location are passed through
        # literally instead of being interpreted.
        argv = [mountcmd] + option.split() + [dev, path]
        cmd = ' '.join(argv)
        try:
            proc = subprocess.Popen(argv, stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT,
                                    universal_newlines=True)
            output = proc.communicate()[0]
        except OSError as e:
            raise RuntimeError('Command: (%s) failed: %s' % (cmd, e))
        if proc.returncode != 0:
            raise RuntimeError('Command: (%s) failed: (%s) %s'
                               % (cmd, proc.returncode, output.strip()))

    def _umount(self, path):
        """Unmount ``path``; the command's exit status is ignored."""
        if os.uname()[0] == 'SunOS':
            cmd = ['/usr/sbin/umount', path]
        else:
            cmd = ['/bin/umount', path]
        subprocess.call(cmd)

    def prepare(self):
        """Create a mount point below ``tmpdir`` and mount the location.

        Raises RuntimeError when mounting fails and OSError when the
        location cannot be inspected or the mount point cannot be created.
        """
        Fetcher.prepare(self)
        self.srcdir = tempfile.mkdtemp(prefix='xenpvboot.', dir=self.tmpdir)
        if self.location.startswith('nfs:'):
            self._mount(self.location[len('nfs:'):], self.srcdir, '-o ro')
            return
        # Block devices mount directly; anything else goes through loop.
        if stat.S_ISBLK(os.stat(self.location)[stat.ST_MODE]):
            option = '-o ro'
        else:
            option = '-o ro,loop'
        if os.uname()[0] == 'SunOS':
            option += ' -F hsfs'
        self._mount(self.location, self.srcdir, option)

    def cleanup(self):
        """Unmount and remove the mount point created by prepare()."""
        if self.srcdir == self.location:
            # prepare() failed before a mount point was created, so srcdir
            # still names the user's location: never unmount or remove it.
            return
        self._umount(self.srcdir)
        try:
            os.rmdir(self.srcdir)
        except OSError:
            # Best effort: the directory may still be busy or already gone.
            pass


class NFSISOFetcher(MountedFetcher):
    """Fetch boot files from an image file that lives on an NFS export.

    The location has the form ``nfs+iso:host:/path/file.iso``: the export
    directory is mounted on ``nfsdir`` and the image inside it is then
    loop-mounted on ``srcdir``.
    """

    def prepare(self):
        """Mount the NFS export and then the image it contains.

        Raises RuntimeError when either mount fails.
        """
        Fetcher.prepare(self)
        self.nfsdir = tempfile.mkdtemp(prefix='xenpvboot.', dir=self.tmpdir)
        self.srcdir = tempfile.mkdtemp(prefix='xenpvboot.', dir=self.tmpdir)
        remote = self.location[len('nfs+iso:'):]
        nfs = os.path.dirname(remote)
        iso = os.path.basename(remote)
        self._mount(nfs, self.nfsdir, '-o ro')
        option = '-o ro,loop'
        if os.uname()[0] == 'SunOS':
            option += ' -F hsfs'
        self._mount(os.path.join(self.nfsdir, iso), self.srcdir, option)

    def cleanup(self):
        """Unmount the image and the export and remove both mount points."""
        if self.srcdir != self.location:
            # Only undo the image mount if prepare() got far enough to
            # create its mount point; otherwise srcdir is still the raw
            # location string.
            MountedFetcher.cleanup(self)
        nfsdir = getattr(self, 'nfsdir', None)
        if nfsdir is None:
            # prepare() failed before the NFS mount point was created.
            return
        # NOTE(review): presumably gives the loop mount time to let go of
        # the image before the export underneath it is unmounted.
        time.sleep(1)
        self._umount(nfsdir)
        try:
            os.rmdir(nfsdir)
        except OSError:
            # Best effort: the directory may still be busy or already gone.
            pass


class TFTPFetcher(Fetcher):
    """Fetch boot files from a ``tftp://host[:port]/path`` location."""

    def get_file(self, filename):
        """Download ``filename`` with /usr/bin/tftp into ``tmpdir``.

        Returns the path of the local copy.  Raises RuntimeError when the
        temporary file cannot be created, the tftp client cannot be run,
        it exits non-zero, or nothing was downloaded.
        """
        target = self.location[len('tftp://'):]
        (netloc, _, basedir) = target.partition('/')
        # "host:port" is handed to the client as two separate arguments.
        host_args = netloc.replace(':', ' ').split()
        remote = os.path.join(basedir, filename)

        try:
            # mkstemp picks an unused name and creates the file atomically.
            (fd, local_name) = tempfile.mkstemp(
                prefix='xenpvboot.%s.' % os.path.basename(filename),
                dir=self.tmpdir)
            os.close(fd)
        except (OSError, IOError) as e:
            raise RuntimeError('Cannot create temporary file in %s: %s'
                               % (self.tmpdir, e))

        # An argument vector (no shell) keeps odd characters in the host or
        # file name from being interpreted as shell syntax.
        argv = ['/usr/bin/tftp'] + host_args + ['-c', 'get', remote, local_name]
        cmd = ' '.join(argv)
        failure = None
        try:
            proc = subprocess.Popen(argv, stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT,
                                    universal_newlines=True)
            output = proc.communicate()[0].strip()
            if proc.returncode != 0:
                failure = '(%s) %s' % (proc.returncode, output)
            elif os.path.getsize(local_name) == 0:
                # NOTE(review): some tftp clients are reported to exit 0
                # after a failed transfer -- confirm.  An empty file is
                # useless as a kernel or ramdisk either way.
                failure = '(0) no data received %s' % output
        except OSError as e:
            failure = str(e)

        if failure is not None:
            try:
                os.unlink(local_name)
            except OSError:
                pass
            raise RuntimeError('Command: (%s) failed: %s' % (cmd, failure))
        return local_name


def _report(opts, message):
    """Write ``message`` to stderr unless --quiet was given."""
    if not opts.quiet:
        sys.stderr.write('%s\n' % message)


def _select_fetcher(location):
    """Return a fetcher object suited to ``location``, or None."""
    is_dir = os.path.isdir(location)
    is_other = os.path.exists(location) and not is_dir
    for scheme in ('http://', 'https://', 'ftp://', 'file://'):
        if location.startswith(scheme):
            return Fetcher(location)
    if is_dir:
        return Fetcher(location)
    if location.startswith('nfs:') or is_other:
        return MountedFetcher(location)
    if location.startswith('nfs+iso:'):
        return NFSISOFetcher(location)
    if location.startswith('tftp://'):
        return TFTPFetcher(location)
    return None


def _fetch_first(fetcher, candidates, opts):
    """Return the local copy of the first fetchable candidate, else None.

    Every failed attempt is reported through _report().
    """
    for name in candidates:
        try:
            return fetcher.get_file(name)
        except Exception as e:
            _report(opts, str(e))
    return None


def main():
    """Command-line entry point.

    Fetches a kernel (and, when available, a ramdisk) from the location
    given on the command line and writes a ``linux (kernel ...)`` line
    naming the local copies to stdout or to the --output file.  Exits with
    status 1 on usage errors, unsupported locations, a failed prepare(), a
    missing kernel, or an explicitly requested ramdisk that cannot be
    fetched.
    """
    usage = '''%prog [options] <location>

Get boot images from the given location and prepare for Xen to use.

Supported locations:

 - http://host/path
 - https://host/path
 - ftp://host/path
 - file:///path
 - tftp://host/path
 - nfs:host:/path
 - /path
 - /path/file.iso
 - /path/filesystem.img
 - /dev/hda1
 - nfs+iso:host:/path/file.iso
 - nfs+iso:host:/path/filesystem.img'''
    version = '%prog version 0.1'
    parser = OptionParser(usage=usage, version=version)
    parser.add_option('--kernel', action='store',
                      help='The kernel image file path relative to location.')
    parser.add_option('--ramdisk', action='store',
                      help='The initial ramdisk file path relative to location.')
    parser.add_option('--args', action='store',
                      help='Arguments pass to the kernel.')
    parser.add_option('--output', action='store',
                      help='Redirect output to this file instead of stdout.')
    parser.add_option('-q', '--quiet', action='store_true',
                      help='Be quiet.')
    (opts, args) = parser.parse_args()

    if len(args) < 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    location = args[0]
    fetcher = _select_fetcher(location)
    if fetcher is None:
        _report(opts, 'Unsupported location: %s' % location)
        sys.exit(1)

    try:
        fetcher.prepare()
    except Exception as e:
        _report(opts, str(e))
        fetcher.cleanup()
        sys.exit(1)

    # The finally clause also runs for the sys.exit() calls below, so the
    # source is always released before the process terminates.
    try:
        if opts.kernel:
            kernel = _fetch_first(fetcher, [opts.kernel], opts)
        else:
            kernel = _fetch_first(fetcher, [k for (k, _) in xen_paths], opts)
        if not kernel:
            _report(opts, 'Cannot get kernel from location: %s' % location)
            sys.exit(1)

        if opts.ramdisk:
            ramdisk = _fetch_first(fetcher, [opts.ramdisk], opts)
            if not ramdisk:
                # An explicitly requested ramdisk is mandatory; drop the
                # kernel copy that nobody is going to use.
                try:
                    os.unlink(kernel)
                except OSError:
                    pass
                sys.exit(1)
        else:
            # Auto-detected ramdisks are optional.
            ramdisk = _fetch_first(fetcher, [r for (_, r) in xen_paths], opts)
    finally:
        fetcher.cleanup()

    sxp = 'linux (kernel %s)' % kernel
    if ramdisk:
        sxp += '(ramdisk %s)' % ramdisk
    if opts.args:
        sxp += '(args %s)' % opts.args

    if not opts.output or opts.output == '-':
        sys.stdout.write(sxp)
        sys.stdout.flush()
    else:
        with open(opts.output, 'w') as f:
            f.write(sxp)


if __name__ == '__main__':
    main()
