#!/usr/bin/env python

'''Page table classes for translating guest virtual addresses to machine addresses in a dump file (x86_32 with and without PAE).'''

import struct

class BasePageTable:
    '''Common base for the architecture-specific page table classes.

    Subclasses set ``self.dump`` (an object providing ``read_page(mfn)``)
    and ``self.arch`` (an object providing ``page_size``, ``page_offset()``
    and ``maddr_to_mfn()``), and implement the translation methods below.
    '''

    def load_data(self, vbase, size):
        '''Return ``size`` bytes of memory starting at virtual address vbase.

        Every virtual page touched by [vbase, vbase + size) is translated
        with ``v2m`` and read from the dump exactly once. The page contents
        are concatenated and trimmed to the requested window. A size of 0
        (or less) yields an empty string.
        '''
        if size <= 0:
            return ''

        page_size = self.arch.page_size
        # Floor division: '/' would produce floats on Python 3 and make every
        # byte look like a new page.
        first_pfn = vbase // page_size
        last_pfn = (vbase + size - 1) // page_size

        chunks = []
        for pfn in range(first_pfn, last_pfn + 1):
            # Translate the first requested byte in each page. For the first
            # page that is vbase itself; for the rest it is the page start.
            vaddr = max(vbase, pfn * page_size)
            mfn = self.arch.maddr_to_mfn(self.v2m(vaddr))
            chunks.append(self.dump.read_page(mfn))

        data = ''.join(chunks)
        offset = self.arch.page_offset(vbase)
        return data[offset:offset + size]

    def virt_to_machine(self, vaddr):
        '''Return the machine address mapped at virtual address vaddr.'''
        raise NotImplementedError

    def v2m(self, vaddr):
        '''Short alias of virt_to_machine; subclasses rebind it.'''
        raise NotImplementedError

    def get_virt_to_machine(self):
        '''Return a {virtual page number: machine page number} mapping.'''
        raise NotImplementedError
        

class PageTable_x86_32(BasePageTable):
    '''page table class for x86_32 without PAE

    Two-level layout: a page directory (called l1 here) of 1024 32-bit
    entries, each mapping either a 4MB page or a page table (l2) of 1024
    32-bit entries that map 4kB pages.
    '''
    MASK_L1 = 0xffc00000        # vaddr bits selecting the l1 entry
    MASK_L2 = 0x003ff000        # vaddr bits selecting the l2 entry
    OFFSET_L1 = 0x003fffff      # offset within a 4MB page
    OFFSET_L2 = 0x00000fff      # offset within a 4kB page
    BASE_4M = 0xffc00000        # machine base bits of a 4MB-page l1 entry
    BASE_4K = 0xfffff000        # machine base bits of a 4kB-aligned entry
    FLAG4MB = 0x00000080        # l1 entry maps a 4MB page
    FLAGPRESENT = 0x00000001    # entry is present
    PAGESHIFT = 12

    def __init__(self, dump, mfn):
        '''create new pagetable from the page directory at machine frame mfn

        Raises NotImplementedError, without touching the dump, when the
        dump's architecture is not x86_32.
        '''
        if dump.arch.name != 'x86_32':
            raise NotImplementedError   # x86_32 only.
        self.dump = dump        # CoreDump obj
        self.arch = dump.arch
        self.mfn = mfn          # mfn of l1 page table
        self.l1 = self.__read_pt(self.mfn, 1)

    def virt_to_machine(self, vaddr):
        '''return machine address from vaddr

        Raises KeyError when no present entry maps vaddr at some level
        (__read_pt keeps only present entries).
        '''
        pde = self.l1[vaddr & self.MASK_L1]

        if pde & self.FLAG4MB:                  # 4MB page
            return (pde & self.BASE_4M) | (vaddr & self.OFFSET_L1)

        # 4kB page: walk into the l2 table referenced by the pde.
        mfn = self.arch.maddr_to_mfn(pde & self.BASE_4K)
        pt = self.__read_pt(mfn, 2)
        pte = pt[vaddr & self.MASK_L2]
        return (pte & self.BASE_4K) | (vaddr & self.OFFSET_L2)

    v2m = virt_to_machine       # alias

    def __get_present_virt(self):
        '''return all present mappings

        Returns a list of (virtual base, machine base, number of 4kB pages)
        tuples: one per 4MB page (1024 pages) and one per present 4kB page.
        '''
        ranges = []
        for page, pde in self.l1.items():
            if not pde & self.FLAGPRESENT:
                continue
            if pde & self.FLAG4MB:
                ranges.append((page, pde & self.BASE_4M, 1024))
                continue
            ptmfn = self.arch.maddr_to_mfn(pde & self.BASE_4K)
            pt = self.__read_pt(ptmfn, 2)
            for subpage, pte in pt.items():
                if pte & self.FLAGPRESENT:
                    ranges.append((page | subpage, pte & self.BASE_4K, 1))
        return ranges

    def get_virt_to_machine(self):
        '''return virtual page to machine page mapping'''
        pages = self.__get_present_virt()
        virt2mach = {}
        for base, mbase, nr in pages:
            for i in range(nr):
                page = (base >> self.PAGESHIFT) + i
                mpage = (mbase >> self.PAGESHIFT) + i
                virt2mach[page] = mpage
        return virt2mach

    def __read_pt(self, mfn, level):
        '''return {vaddr bits: entry} for the present entries of table mfn

        level 1 is the page directory (entry index -> vaddr << 22),
        level 2 a page table (entry index -> vaddr << 12).
        Raises ValueError for any other level.
        '''
        if level == 1:
            shift = 22
        elif level == 2:
            shift = 12
        else:
            raise ValueError('invalid page table level: %r' % (level,))

        # x86 entries are 32-bit little-endian regardless of the host. A
        # native "L" is 8 bytes on LP64 hosts and would not fit one page.
        pt = struct.unpack('<1024I', self.dump.read_page(mfn))

        phash = {}
        for (index, entry) in enumerate(pt):
            if entry & self.FLAGPRESENT:
                phash[index << shift] = entry

        return phash

class PageTable_x86_32pae(BasePageTable):
    '''page table class for x86_32 with PAE

    Three-level layout with 64-bit entries: a 4-entry top table (l0),
    page directories (l1) whose entries map 2MB pages or page tables (l2),
    and page tables that map 4kB pages.
    '''

    MASK_L0 = 0xc0000000                    # vaddr bits selecting the l0 entry
    MASK_L1 = 0x3fe00000                    # vaddr bits selecting the l1 entry
    MASK_L2 = 0x001ff000                    # vaddr bits selecting the l2 entry
    OFFSET_L1 =   0x00000000001fffff        # offset within a 2MB page
    OFFSET_L2 =   0x0000000000000fff        # offset within a 4kB page
    BASE_4K =     0x0000000ffffff000        # machine base bits, 4kB-aligned entry
    BASE_2M =     0x0000000fffe00000        # machine base bits, 2MB-page entry
    FLAG2MB =     0x0000000000000080        # l1 entry maps a 2MB page
    FLAGPRESENT = 0x0000000000000001        # entry is present
    PAGESHIFT = 12

    def __init__(self, dump, mfn):
        '''create new pagetable from the top-level table at machine frame mfn

        Raises NotImplementedError, without touching the dump, when the
        dump's architecture is not x86_32pae.
        '''
        if dump.arch.name != 'x86_32pae':
            raise NotImplementedError   # x86_32pae only.
        self.dump = dump                # CoreDump obj
        self.arch = dump.arch
        self.mfn = mfn                  # mfn of the l0 table
        self.l0 = self.__read_pt(self.mfn, 0)

    def virt_to_machine(self, vaddr):
        '''return machine address from vaddr

        Raises KeyError when no present entry maps vaddr at some level
        (__read_pt keeps only present entries).
        '''
        pdpte = self.l0[vaddr & self.MASK_L0]
        pdmfn = self.arch.maddr_to_mfn(pdpte & self.BASE_4K)
        pd = self.__read_pt(pdmfn, 1)
        pde = pd[vaddr & self.MASK_L1]

        if pde & self.FLAG2MB:                  # 2MB page
            return (pde & self.BASE_2M) | (vaddr & self.OFFSET_L1)

        # 4kB page: walk into the l2 table referenced by the pde.
        ptmfn = self.arch.maddr_to_mfn(pde & self.BASE_4K)
        pt = self.__read_pt(ptmfn, 2)
        pte = pt[vaddr & self.MASK_L2]
        return (pte & self.BASE_4K) | (vaddr & self.OFFSET_L2)

    v2m = virt_to_machine               # alias

    def __get_present_virt(self):
        '''return all present mappings

        Returns a list of (virtual base, machine base, number of 4kB pages)
        tuples: one per 2MB page (512 pages) and one per present 4kB page.
        '''
        ranges = []
        for pdpteaddr, pdpte in self.l0.items():
            if not pdpte & self.FLAGPRESENT:
                continue
            pdmfn = self.arch.maddr_to_mfn(pdpte & self.BASE_4K)
            pd = self.__read_pt(pdmfn, 1)
            for pdeaddr, pde in pd.items():
                if not pde & self.FLAGPRESENT:
                    continue
                vbase = pdpteaddr | pdeaddr
                if pde & self.FLAG2MB:
                    # A 2MB entry's frame is BASE_2M-aligned (as in
                    # virt_to_machine). BASE_4K would also keep bits 20..12,
                    # which are not part of the large page's base address.
                    ranges.append((vbase, pde & self.BASE_2M, 512))
                    continue
                ptmfn = self.arch.maddr_to_mfn(pde & self.BASE_4K)
                pt = self.__read_pt(ptmfn, 2)
                for pteaddr, pte in pt.items():
                    if pte & self.FLAGPRESENT:
                        ranges.append((vbase | pteaddr, pte & self.BASE_4K, 1))
        return ranges

    def get_virt_to_machine(self):
        '''return virtual page to machine page mapping'''
        pages = self.__get_present_virt()
        virt2mach = {}
        for base, mbase, nr in pages:
            for i in range(nr):
                page = (base >> self.PAGESHIFT) + i
                mpage = (mbase >> self.PAGESHIFT) + i
                virt2mach[page] = mpage
        return virt2mach

    def __read_pt(self, mfn, level):
        '''return {vaddr bits: entry} for the present entries of table mfn

        level 0: first 4 entries, index -> vaddr << 30
        level 1: 512 entries, index -> vaddr << 21
        level 2: 512 entries, index -> vaddr << 12
        Raises ValueError for any other level.
        '''
        if level == 0:
            entries, shift = 4, 30
        elif level == 1:
            entries, shift = 512, 21
        elif level == 2:
            entries, shift = 512, 12
        else:
            raise ValueError('invalid page table level: %r' % (level,))

        # Entries are 64-bit little-endian regardless of the host byte order.
        pt = struct.unpack('<512Q', self.dump.read_page(mfn))[:entries]

        phash = {}
        for (index, entry) in enumerate(pt):
            if entry & self.FLAGPRESENT:
                phash[index << shift] = entry

        return phash

# Architecture name (as reported by dump.arch.name) -> page table class.
# None marks architectures that have no page table implementation yet.
arch = dict(
    x86_32=PageTable_x86_32,
    x86_32pae=PageTable_x86_32pae,
    x86_64=None,
    ia64=None,
)

def PageTable(dump, mfn):
    '''Return the page table object for dump's architecture, rooted at mfn.

    The class is looked up in the module-level ``arch`` table by
    ``dump.arch.name``. Raises NotImplementedError when that architecture
    is missing from the table or mapped to None.
    '''
    name = dump.arch.name
    cls = arch.get(name)
    if cls is None:
        raise NotImplementedError(
            'page tables are not supported for arch %r' % (name,))
    return cls(dump, mfn)


if __name__ == '__main__':
    # Ad-hoc debugging aid: print the present virtual ranges of the page
    # table at mfn 0x7246 in the core file 'dump'.
    import ElfCore
    xendump = ElfCore.ElfCoreReader('dump', 'x86_32')
    pt = PageTable(xendump, 0x7246)
    # __get_present_virt is name-mangled inside its class, so module-level
    # code must use the mangled name. This assumes the dump reports arch
    # 'x86_32' and therefore yields a PageTable_x86_32. TODO confirm.
    vfns = pt._PageTable_x86_32__get_present_virt()
    for virt, mach, pages in vfns:
        print((hex(virt), hex(mach), pages))
