#! /usr/bin/python3
# AWS EC2 HibInit Agent. This agent does several things upon startup:
# 1. It checks for sufficient swap space to allow hibernation and fails if not enough space
# 2. If there's no swap file or the existing swap file isn't of a sufficient size, a swap file is created
#     1. If `touch-swap` is enabled, all the swap file's blocks will be touched
#         so that the root EBS volume is pre-warmed.
# 3. It updates the offset of the swap file in the kernel using `snapshot_set_swap_area` ioctl.
# 4. It updates the resume offset and resume device in grub file.
#
# This file requires Python 3.6 or newer: it uses f-strings and FileNotFoundError, so it does not run on Python 2.
import argparse
import array
import ctypes as ctypes
import fcntl
import mmap
import os
import struct
import sys
import syslog
import math
import requests
import signal
from subprocess import check_call, check_output

try:
    from ConfigParser import ConfigParser, NoSectionError, NoOptionError
except:
    from configparser import ConfigParser, NoSectionError, NoOptionError

# Directory of GRUB2 configuration snippets; patch_grub_config() writes its
# '99-set-swap.cfg' file here.
GRUB2_DIR = '/etc/default/grub.d'

# Space reserved for swap headers
SWAP_RESERVED_SIZE = 16384
# Whether print_to_sys_log() emits anything; main() overwrites this with the
# effective 'log-to-syslog' setting.
LOG_TO_SYSLOG = True

# Absolute path of the swap file this agent creates, sizes and removes.
SWAP_FILE = '/swap-hibinit'

# Default directory for agent state, used when 'state-dir' is not configured.
DEFAULT_STATE_DIR = '/var/lib/hibinit-agent'
# Name of the marker file (inside the state dir) recording that hibernation
# was found to be enabled, so IMDS does not have to be queried again.
HIB_ENABLED_FILE = "hibernation-enabled"
# Instance metadata service endpoint and the paths queried below it.
IMDS_BASEURL = 'http://169.254.169.254'
IMDS_API_TOKEN_PATH = 'latest/api/token'
# Despite the name, this is the path hibernation_enabled() reads to learn
# whether hibernation is configured for the instance.
IMDS_SPOT_ACTION_PATH = 'latest/meta-data/hibernation/configured'

# Default Values
KB = 1024
# How many bytes an existing swap file may exceed the accepted minimum by
# before validate_system_requirements() deletes it as oversized.
MAX_SWAP_SIZE_OFFSET_ALLOWED = 100 * KB
# Fallback for the 'target-size-mb' setting (multiplied by KB * KB to get bytes).
DEFAULT_SWAP_SIZE_MB = 4000
# Fallback for the 'percentage-of-ram' setting.
DEFAULT_SWAP_PERCENTAGE = 100

# Sigterm_handler Behaviour Modifiers
# Set to True by critical_process_sigterm_handler() when SIGTERM arrives while
# a @critical_process function is running.
SHUTDOWN_REQUESTED = False


def print_to_sys_log(message):
    """Send ``message`` to syslog unless logging has been switched off.

    The module-level ``LOG_TO_SYSLOG`` flag is read on every call, so changing
    it at runtime affects all later messages.
    """
    if not LOG_TO_SYSLOG:
        return
    syslog.syslog(message)


def critical_process_sigterm_handler(signum, frame):
    """SIGTERM handler for critical sections: only record the request.

    Nothing is interrupted here; the module-level ``SHUTDOWN_REQUESTED`` flag
    is raised so the code that installed this handler can exit later.
    """
    global SHUTDOWN_REQUESTED
    # Both arguments are mandated by the signal-handler signature but unused.
    del signum, frame
    SHUTDOWN_REQUESTED = True


def default_sigterm_handler(signum, frame):
    """SIGTERM handler used outside critical sections: undo the swap setup and exit.

    If ``SWAP_FILE`` is currently an active swap area it is turned off, then
    the file is removed (when it exists and is readable) and the process
    exits with status 0.

    Raises:
        subprocess.CalledProcessError: if ``swapoff`` fails.
        SystemExit: always, on the normal path.
    """
    # /proc/swaps holds a header row followed by one row per active swap
    # area; the first whitespace-separated field is the path. Reading it
    # directly avoids the earlier `swapon -s | cut | grep` pipeline, which
    # raised when the file was not active (grep exits 1) and compared bytes
    # with str, so the swapoff branch could never be taken.
    swap_is_active = False
    try:
        with open('/proc/swaps') as active_swaps:
            swap_is_active = any(
                entry.split()[:1] == [SWAP_FILE] for entry in active_swaps)
    except OSError as e:
        # Without the table we cannot tell; still try to remove the file below.
        print_to_sys_log("Could not read /proc/swaps: %s" % str(e))

    if swap_is_active:
        check_call(['swapoff', SWAP_FILE])
    if os.path.isfile(SWAP_FILE) and os.access(SWAP_FILE, os.R_OK):
        os.remove(SWAP_FILE)
    sys.exit(0)


def critical_process(function_with_critical_process):
    """Decorator that defers SIGTERM while the wrapped function runs.

    While the function executes, SIGTERM is handled by
    ``critical_process_sigterm_handler`` (which only sets
    ``SHUTDOWN_REQUESTED``). Once it finishes, the process exits with status 0
    if a shutdown was requested in the meantime; otherwise the wrapped
    function's return value is passed back to the caller.

    ``default_sigterm_handler`` is re-installed on every way out, including
    when the wrapped function raises, so a failure cannot leave the process
    with a SIGTERM handler that merely sets a flag.
    """
    # Local import: functools is not imported at file level.
    from functools import wraps

    @wraps(function_with_critical_process)
    def wrapper(*args, **kwargs):
        signal.signal(signal.SIGTERM, critical_process_sigterm_handler)
        try:
            result = function_with_critical_process(*args, **kwargs)
            if SHUTDOWN_REQUESTED:
                sys.exit(0)
        finally:
            signal.signal(signal.SIGTERM, default_sigterm_handler)
        return result

    return wrapper


def fallocate(fl, size):
    """Make the open file ``fl`` ``size`` bytes long, preferring fallocate(2).

    Args:
        fl: an open, writable file object.
        size: desired length in bytes.

    libc's ``fallocate`` is called with mode 0 from offset 0. If libc cannot
    be loaded or the call fails, the reason is logged and the file is
    extended by seeking to ``size - 1`` and writing a single zero byte.

    NOTE(review): the fallback only writes the last byte, so it presumably
    leaves a sparse file; confirm that the swap tooling accepts that.
    """
    try:
        # use_errno lets us report why fallocate() failed, not just "-1".
        _libc = ctypes.CDLL('libc.so.6', use_errno=True)
        _fallocate = _libc.fallocate
        _fallocate.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_ulong, ctypes.c_ulong]
        _fallocate.restype = ctypes.c_int

        # (FD, mode, offset, len)
        res = _fallocate(fl.fileno(), 0, 0, size)
        if res != 0:
            err = ctypes.get_errno()
            raise OSError(err, os.strerror(err))
    except Exception as e:
        print_to_sys_log("Failed to call fallocate(), will use resize. Err: %s" % str(e))
        fl.seek(size - 1)
        # Match the file's mode so binary-mode handles work too.
        fl.write(b'\0' if 'b' in getattr(fl, 'mode', '') else '\0')


def get_file_block_number(filename):
    """Return the filesystem block number backing the first block of ``filename``.

    Issues the FIBMAP ioctl for logical block 0 and returns the block number
    the kernel writes back.

    Raises:
        OSError: if the file cannot be opened or the ioctl fails
            (``fcntl.ioctl`` raises rather than returning a negative value).
    """
    # from linux/fs.h
    fibmap = 0x01
    # FIBMAP reads and writes a C `int`; a 4-byte buffer matches that on every
    # architecture (an 8-byte 'L' buffer only works on little-endian hosts).
    buf = array.array('I', [0])
    with open(filename, 'rb') as handle:
        fcntl.ioctl(handle.fileno(), fibmap, buf)
    return buf[0]


def get_rootfs_size():
    """Return the total size of the filesystem mounted at '/' in GiB, rounded up."""
    stat = os.statvfs('/')
    # f_blocks is counted in f_frsize units (f_bsize is only the preferred
    # I/O size), matching the free-space computation elsewhere in this file.
    return math.ceil(float(stat.f_frsize * stat.f_blocks) / (KB * KB * KB))


# This is only for grub2
def find_grub_mount():
    """Return the mount point holding the first readable GRUB2 config file.

    The known config locations are probed in order; for the first one that is
    a readable file, ``stat -L -c %m`` is asked for its mount point.

    Returns:
        The mount point as a string, or None when no config file is found.
    """
    candidate_configs = ('/etc/grub2-efi.cfg',
                         '/boot/grub2/grub.cfg',
                         '/boot/grub2-efi/grub.cfg',
                         '/etc/grub2.cfg',
                         '/boot/grub/grub.cfg')
    for cfg_path in candidate_configs:
        if os.path.isfile(cfg_path) and os.access(cfg_path, os.R_OK):
            # Argument list instead of a shell string: no shell is needed here.
            mount = check_output(['stat', '-L', '-c', '%m', cfg_path],
                                 universal_newlines=True)
            return mount.strip()
    return None


def get_partuuid(device):
    """Return the PARTUUID column that ``lsblk`` prints for ``device``.

    The output is decoded as ASCII and stripped of surrounding whitespace.
    """
    lsblk_cmd = ['lsblk', '-dno', 'PARTUUID', device]
    raw_output = check_output(lsblk_cmd)
    return raw_output.decode('ascii').strip()

def partuuid_to_device(partuuid):
    """Resolve a partition UUID to the device node it links to.

    Looks the UUID up under /dev/disk/by-partuuid and returns the real path
    behind that entry.

    Raises:
        FileNotFoundError: if no such entry exists.
    """
    path = "/dev/disk/by-partuuid/{}".format(partuuid)
    if os.path.exists(path):
        return os.path.realpath(path)

    raise FileNotFoundError("{} does not exist".format(path))

def patch_grub_config(swap_device, offset, grub2_dir):
    """Record the resume device and offset in GRUB and in the running kernel.

    Args:
        swap_device: device holding the swap file, as reported by ``df``.
        offset: resume offset of the swap file on that device.
        grub2_dir: directory for GRUB2 config snippets.

    When the ``99-set-swap.cfg`` snippet in ``grub2_dir`` is missing or
    differs from the desired content, it is rewritten, ``update-grub2`` is
    run, the filesystem holding the GRUB config is synced and
    frozen/unfrozen, and /sys/power/resume and /sys/power/resume_offset are
    updated when present. Nothing is changed if no GRUB config is found,
    ``grub2_dir`` does not exist, or the snippet is already up to date.
    """
    mount_point = find_grub_mount()

    if mount_point is None:
        print_to_sys_log("Grub configuration is not updated. Grub cannot found under /boot or /etc. " +
                         "Please run manual grub update with resume parameters")
        return

    print_to_sys_log("Updating GRUB to use the device %s with offset %d for resume" % (swap_device, offset))

    # Update GRUB2 config when needed
    if not (grub2_dir and os.path.exists(grub2_dir)):
        return

    offset_file = os.path.join(grub2_dir, '99-set-swap.cfg')
    resume_swap_device = swap_device
    if swap_device.startswith("/dev"):
        partuuid = get_partuuid(swap_device)
        # Only switch to PARTUUID when lsblk reported one; otherwise keep the
        # plain device path for both the kernel command line and sysfs.
        if partuuid:
            swap_device = "PARTUUID=%s" % partuuid
            resume_swap_device = partuuid_to_device(partuuid)
    grub_snippet = (
        'GRUB_CMDLINE_LINUX_DEFAULT="$GRUB_CMDLINE_LINUX_DEFAULT'
        ' no_console_suspend=1 resume_offset=%d resume=%s"\n'
        % (offset, swap_device))

    current_snippet = None
    if os.path.exists(offset_file):
        with open(offset_file) as fl:
            current_snippet = fl.read()
    # Exact comparison: a substring test would treat an empty or truncated
    # snippet file as already up to date.
    if current_snippet == grub_snippet:
        return

    print_to_sys_log("Updating GRUB to use the device %s with offset %d for resume"
                     % (swap_device, offset))
    with open(offset_file, 'w') as fl:
        fl.write(grub_snippet)
    check_call('/usr/sbin/update-grub2')
    print_to_sys_log("GRUB configuration is updated")

    # Sync, then freeze and immediately unfreeze the filesystem holding the
    # GRUB config (presumably to force the new config onto disk).
    fsfreeze = "sync && mountpoint -q {mount} && trap '' HUP INT QUIT TERM && " + \
               "fsfreeze -f {mount} && fsfreeze -u {mount}"
    fsfreeze = fsfreeze.format(mount=mount_point)
    check_call(fsfreeze, shell=True)

    # Some grubby versions need a restart after changing in kernel parameters
    # echo offset and swap device helps the customer to use the agent immediately
    # after rpm installation.
    if os.path.exists("/sys/power/resume"):
        echo_resume_device = "echo {resume_swap_device} > /sys/power/resume"
        echo_resume_device = echo_resume_device.format(resume_swap_device=resume_swap_device)
        print_to_sys_log("sys/power/resume exist and will be updated")
        check_call(echo_resume_device, shell=True)

    if os.path.exists("/sys/power/resume_offset"):
        echo_resume_offset = "echo {offset} > /sys/power/resume_offset"
        echo_resume_offset = echo_resume_offset.format(offset=offset)
        print_to_sys_log("sys/power/resume_offset exist and will be updated")
        check_call(echo_resume_offset, shell=True)

    print_to_sys_log("GRUB configuration is updated")


@critical_process
def update_kernel_swap_offset(config):
    """Tell GRUB and the kernel where the swap file lives, for resume.

    Runs the configured ``swapon`` command, determines the swap file's offset
    (via ``btrfs inspect-internal map-swapfile`` on btrfs, FIBMAP otherwise),
    optionally updates the GRUB configuration, and, on non-btrfs systems with
    /dev/snapshot, passes offset and device to the kernel through the
    SNAPSHOT_SET_SWAP_AREA ioctl. The configured ``swapoff`` command is run
    afterwards even if one of the steps fails, so swap is not left enabled.

    Args:
        config: the effective ``Config`` (uses swapon, swapoff, btrfs_enabled
            and grub_update).
    """
    swapon = config.swapon.format(swapfile=SWAP_FILE)
    print_to_sys_log("Running: %s" % swapon)
    check_call(swapon, shell=True)
    try:
        print_to_sys_log("Updating the kernel offset for the swapfile: %s" % SWAP_FILE)

        dev = os.stat(SWAP_FILE).st_dev
        if config.btrfs_enabled:
            get_offset = "btrfs inspect-internal map-swapfile -r {swapfile}".format(swapfile=SWAP_FILE)
            offset = int(check_output(get_offset, shell=True).decode(sys.getfilesystemencoding()))
        else:
            offset = get_file_block_number(SWAP_FILE)

        if config.grub_update:
            dev_str = find_device_for_file(SWAP_FILE)
            patch_grub_config(dev_str, offset, GRUB2_DIR)
        else:
            print_to_sys_log("Skipping GRUB configuration update")

        print_to_sys_log("Setting swap device to %d with offset %d" % (dev, offset))

        if not config.btrfs_enabled and os.path.exists("/dev/snapshot"):
            # Set the kernel swap offset, see https://www.kernel.org/doc/Documentation/power/userland-swsusp.txt
            # From linux/suspend_ioctls.h
            snapshot_set_swap_area = 0x400C330D
            # The ioctl number encodes a 12-byte argument: a 64-bit offset
            # followed by a 32-bit device number. '=qI' yields exactly that
            # layout regardless of the native size of `long`.
            buf = struct.pack('=qI', offset, dev)
            with open('/dev/snapshot', 'r') as snap:
                fcntl.ioctl(snap, snapshot_set_swap_area, buf)
            print_to_sys_log("Done updating the swap offset. Turning swapoff")
    finally:
        swapoff = config.swapoff.format(swapfile=SWAP_FILE)
        print_to_sys_log("Running: %s" % swapoff)
        check_call(swapoff, shell=True)


def find_device_for_file(filename):
    """Return the first column ``df -P`` reports for the filesystem holding ``filename``."""
    # 'df -P <path>' prints a header row, then the row for the filesystem
    # containing <path>; the first field of that row is what we want.
    raw_output = check_output(['df', '-P', filename])
    rows = raw_output.decode(sys.getfilesystemencoding()).split("\n")
    filesystem_row = rows[1]
    return filesystem_row.split()[0]


class SwapInitializer(object):
    """Creates, optionally pre-touches, and formats the swap file at ``SWAP_FILE``."""

    def __init__(self, config):
        # NOTE(review): validate_system_requirements() may demand more than
        # swap_mb (it also considers the RAM percentage), while this size is
        # derived from swap_mb alone - confirm this is intended.
        self.swap_size = config.swap_mb * KB * KB
        self.config = config
        self.block_size = KB * KB  # 1 MiB

    def init_swap(self):
        """
            Initialize the swap using direct IO to avoid polluting the page cache

            Steps: optional btrfs preparation, allocation, optional touching of
            every block, then the configured mkswap command.
        """
        if self.config.btrfs_enabled:
            self.setup_btrfs()

        self.do_allocate()

        if self.config.touch_swap:
            self.pretouch_swap()
        else:
            print_to_sys_log("Swap pre-heating is skipped, the swap blocks won't be touched during "
                             "to ensure they are ready")
        self.init_mkswap()

    def setup_btrfs(self):
        """Create/truncate the swap file and mark it No-Copy-On-Write (``chattr +C``)."""
        no_cow = "truncate -s 0 {swapfile} && chattr +C {swapfile}".format(swapfile=SWAP_FILE)
        print_to_sys_log("Setting /swap to No Copy On Write")
        check_call(no_cow, shell=True)

    def do_allocate(self):
        """Allocate ``self.swap_size`` bytes for the swap file and restrict it to mode 0600."""
        print_to_sys_log("Allocating %d bytes in %s" % (self.swap_size, SWAP_FILE))
        with open(SWAP_FILE, 'w+') as fl:
            fallocate(fl, self.swap_size)
        os.chmod(SWAP_FILE, 0o600)  # Read + Write Permissions

    def pretouch_swap(self):
        """Write every block of the swap file once using direct, synchronous IO.

        Raises:
            OSError: if the file cannot be opened or a write fails.
            Exception: if a write reports that no bytes were written.
        """
        written = 0
        print_to_sys_log("Opening %s for direct IO" % SWAP_FILE)
        # os.open() raises OSError on failure, so no return-value check is needed.
        fd = os.open(SWAP_FILE, os.O_RDWR | os.O_DIRECT | os.O_SYNC | os.O_DSYNC)

        filler_block = None
        try:
            # Create a filler block that is correctly aligned for direct IO
            filler_block = mmap.mmap(-1, self.block_size)
            # We're using 'b' to avoid optimizations that might happen for zero-filled pages
            filler_block.write(b'b' * self.block_size)

            print_to_sys_log("Touching all blocks in %s" % SWAP_FILE)

            while written < self.swap_size:
                res = os.write(fd, filler_block)
                if res <= 0:
                    # Guards against looping forever on a write that makes no progress.
                    raise Exception("Failed to touch a block. Err: write returned %d" % res)
                written += res
        finally:
            os.close(fd)
            if filler_block is not None:
                filler_block.close()
        print_to_sys_log("Swap file %s is ready" % SWAP_FILE)

    def init_mkswap(self):
        """Run the configured mkswap command on the swap file.

        Raises:
            Exception: wrapping (and chained to) whatever made the command fail.
        """
        # Do mkswap
        try:
            mkswap = self.config.mkswap.format(swapfile=SWAP_FILE)
            print_to_sys_log("Running: %s" % mkswap)
            check_call(mkswap, shell=True)
        except Exception as e:
            raise Exception("Failed to setup swap area, reason: %s" % str(e)) from e


def identify_file_system(swapfile, file_system):
    """Return True if ``swapfile`` would live on a filesystem of type ``file_system``.

    The parent directories of ``swapfile`` are walked upwards until ``df`` can
    name the device backing one of them; /proc/mounts is then searched for an
    entry whose device and filesystem-type fields both match.

    Raises:
        Exception: if no ancestor directory can be resolved to a device.
    """
    # Walk the parent directories of the swapfile to find on which
    # filesystem it's mounted
    swap_place = swapfile
    dev = None
    while not dev:
        parent = os.path.dirname(swap_place)
        try:
            dev = find_device_for_file(parent)
        except Exception:
            dev = None
        # Stop at the top of the tree (or when the path stops shrinking, e.g.
        # for a relative name) instead of looping forever.
        if not dev and (parent == swap_place or parent in ('/', '')):
            raise Exception("Failed to find the filesystem type of %s" % swapfile)
        swap_place = parent

    with open("/proc/mounts") as fl:
        for ln in fl:
            # Fields: device, mount point, filesystem type, options, ...
            # Compare whole fields so that e.g. /dev/sda1 does not match
            # /dev/sda10, and a type name inside a path does not count.
            fields = ln.split()
            if len(fields) >= 3 and fields[0] == dev and fields[2] == file_system:
                return True
    return False


class Config(object):
    """Effective agent settings.

    Every setting is resolved with the same precedence: command-line
    argument, then configuration file, then a default.
    """

    def __init__(self, config_file, args):
        """
        Args:
            config_file: a ConfigParser (possibly empty) holding file settings.
            args: the argparse namespace holding command-line settings.

        Raises:
            ValueError: if a boolean or integer setting cannot be parsed.
        """
        def get_from_config(section, name):
            try:
                return config_file.get(section, name)
            except (NoSectionError, NoOptionError):
                return None

        def get_int_from_config(section, name):
            v = get_from_config(section, name)
            return None if v is None else int(v)

        def get_core_bool(arg_value, name):
            # Explicit boolean from the CLI or the [core] section, else None.
            return self.merge(
                self.to_bool(arg_value),
                self.to_bool(get_from_config('core', name)),
                None)

        def explicit_or_filesystem(explicit, file_system):
            # Probe the filesystem (df + /proc/mounts) only when nothing was
            # configured, so an explicit setting neither pays for the probe
            # nor fails because of it.
            if explicit is not None:
                return explicit
            return identify_file_system(SWAP_FILE, file_system)

        self.log_to_syslog = self.merge(get_core_bool(args.log_to_syslog, 'log-to-syslog'), None, True)
        self.grub_update = self.merge(get_core_bool(args.grub_update, 'grub-update'), None, True)

        self.touch_swap = explicit_or_filesystem(
            get_core_bool(args.touch_swap, 'touch-swap'), "xfs")
        self.btrfs_enabled = explicit_or_filesystem(
            get_core_bool(args.btrfs_enabled, 'btrfs-enabled'), "btrfs")
        self.state_dir = self.merge(None, get_from_config('core', 'state-dir'), DEFAULT_STATE_DIR)

        self.swap_percentage = self.merge(args.swap_ram_percentage, get_int_from_config('swap', 'percentage-of-ram'), DEFAULT_SWAP_PERCENTAGE)
        self.swap_mb = self.merge(args.swap_target_size_mb, get_int_from_config('swap', 'target-size-mb'), DEFAULT_SWAP_SIZE_MB)
        self.mkswap = self.merge(args.mkswap, get_from_config('swap', 'mkswap'), 'mkswap {swapfile}')
        self.swapon = self.merge(args.swapon, get_from_config('swap', 'swapon'), 'swapon {swapfile}')
        self.swapoff = self.merge(args.swapoff, get_from_config('swap', 'swapoff'), 'swapoff {swapfile}')

    def merge(self, arg_value, cf_value, def_val):
        """Return the first of (argument, config-file value, default) that is not None."""
        if arg_value is not None:
            return arg_value
        if cf_value is not None:
            return cf_value
        return def_val

    def to_bool(self, bool_str):
        """Parse the string and return a boolean value, None, or raise an exception

        Accepts true/t/1/yes/on and false/f/0/no/off, ignoring case and
        surrounding whitespace. None is passed through unchanged.

        Raises:
            ValueError: if the string is not one of the accepted spellings.
        """
        if bool_str is None:
            return None
        normalized = bool_str.strip().lower()
        if normalized in ('true', 't', '1', 'yes', 'on'):
            return True
        if normalized in ('false', 'f', '0', 'no', 'off'):
            return False
        # if here we couldn't parse it
        raise ValueError("%s is not recognized as a boolean value" % bool_str)

    def __str__(self):
        return str(self.__dict__)


def get_imds_token(seconds=21600, timeout=5):
    """ Get a token to access instance metadata.

    Args:
        seconds: requested token lifetime, sent as the TTL header.
        timeout: seconds to wait for the metadata service before giving up.

    Returns:
        The token text, or None if the service cannot be reached or answers
        with anything other than HTTP 200.
    """
    print_to_sys_log("Requesting new IMDSv2 token.")
    request_header = {'X-aws-ec2-metadata-token-ttl-seconds': '{}'.format(seconds)}
    token_url = '{}/{}'.format(IMDS_BASEURL, IMDS_API_TOKEN_PATH)
    try:
        # Without a timeout an unreachable endpoint would block startup forever.
        response = requests.put(token_url, headers=request_header, timeout=timeout)
    except requests.exceptions.RequestException as e:
        print_to_sys_log("Failed to request IMDSv2 token: %s" % str(e))
        return None
    response.close()
    if response.status_code != 200:
        return None

    return response.text


def create_state_dir(state_dir):
    """ Create agent run dir if it doesn't exist.

    Missing parent directories are created too. Raises OSError if the path
    exists but is not a directory, or if it cannot be created.
    """
    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(state_dir, exist_ok=True)


def hibernation_enabled(state_dir, timeout=5):
    """Returns a boolean indicating whether hibernation is enabled or not.

    Hibernation can't be enabled/disabled after the instance launch. If we find
    hibernation to be enabled, we create a semaphore file so that we don't
    have to probe IMDS again. That is useful when an instance is rebooted
    after/if the IMDS http endpoint has been disabled.

    Args:
        state_dir: directory holding the semaphore file.
        timeout: seconds to wait for the metadata service before giving up.

    Returns:
        True if the semaphore file exists or IMDS reports hibernation as
        configured; False if no token could be obtained, the request fails,
        the status is not 200, or the answer is "false".
    """
    hib_sem_file = os.path.join(state_dir, HIB_ENABLED_FILE)
    if os.path.isfile(hib_sem_file):
        print_to_sys_log("Found {!r}, configuring hibernation".format(hib_sem_file))
        return True

    imds_token = get_imds_token()
    if imds_token is None:
        # IMDS http endpoint is disabled
        return False

    request_header = {'X-aws-ec2-metadata-token': imds_token}
    try:
        response = requests.get("{}/{}".format(IMDS_BASEURL, IMDS_SPOT_ACTION_PATH),
                                headers=request_header, timeout=timeout)
    except requests.exceptions.RequestException as e:
        print_to_sys_log("Failed to query hibernation configuration from IMDS: %s" % str(e))
        return False
    response.close()
    # strip() so trailing whitespace/newlines in the body cannot turn "false" into a yes.
    if response.status_code != 200 or response.text.strip().lower() == "false":
        return False

    print_to_sys_log("Hibernation Configured Flag found")
    # Create the (empty) semaphore file; harmless if it already exists.
    with open(hib_sem_file, 'a'):
        pass
    return True


# Returns a tuple denoting (valid_to_init_hibernation: bool, new_swap_file_required: bool)
def validate_system_requirements(config):
    """Decide whether hibernation can be set up and whether a swap file must be created.

    The required swap size is the larger of the configured size and the
    configured percentage of RAM. An existing swap file is kept only when its
    size lies within the accepted window around that requirement; otherwise it
    is deleted and the free space for a new one is checked.
    """
    mib = KB * KB
    gib = mib * KB
    required_bytes = config.swap_mb * mib  # Converting to bytes

    # The root filesystem must be larger than RAM (compared in whole GiB)
    ram_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
    ram_gib = math.ceil(float(ram_bytes) / gib)
    if get_rootfs_size() <= ram_gib:
        print_to_sys_log(
            "Insufficient disk space. Cannot create setup for hibernation. Please allocate a larger root device")
        return False, False

    ram_share = ram_bytes * config.swap_percentage // 100
    if ram_share > required_bytes:
        required_bytes = int(ram_share)
    print_to_sys_log("Will check if swap is at least: %d megabytes" % (required_bytes // mib))

    # Size of the current swap file, or 0 when there is none we can read
    have_swap_file = os.path.isfile(SWAP_FILE) and os.access(SWAP_FILE, os.R_OK)
    existing_bytes = os.path.getsize(SWAP_FILE) if have_swap_file else 0

    smallest_ok = required_bytes - SWAP_RESERVED_SIZE
    largest_ok = smallest_ok + MAX_SWAP_SIZE_OFFSET_ALLOWED

    if existing_bytes > largest_ok:
        # Too big: drop it and build a right-sized one
        print_to_sys_log("Swap already exists! (have %d, need %d), deleting existing swap file %s since current swap is "
                         "sufficiently large and wasting disk space" % (existing_bytes, required_bytes, SWAP_FILE))
        os.remove(SWAP_FILE)
    elif existing_bytes >= smallest_ok:
        print_to_sys_log("There's sufficient swap available (have %d, need %d)" %
                         (existing_bytes, required_bytes))
        return True, False
    elif existing_bytes > 0:
        # Too small (e.g. instance launched from a pre-created image): re-create the swap
        print_to_sys_log("Swap already exists! (have %d, need %d), deleting existing swap file %s" %
                         (existing_bytes, required_bytes, SWAP_FILE))
        os.remove(SWAP_FILE)

    # A new swap file is needed; make sure it fits, with 10 MiB of headroom
    swap_dir = os.path.dirname(SWAP_FILE)
    fs_stats = os.statvfs(swap_dir)
    available_bytes = fs_stats.f_bavail * fs_stats.f_frsize
    needed_bytes = required_bytes + 10 * mib
    enough_room = needed_bytes < available_bytes
    if enough_room:
        print_to_sys_log("There's enough space (%d present, %d needed) for the swap file on the device: %s" % (
            available_bytes, needed_bytes, swap_dir))
    else:
        print_to_sys_log("There's not enough space (%d present, %d needed) for the swap file on the device: %s" % (
            available_bytes, needed_bytes, swap_dir))
    return enough_room, True


def main():
    """Entry point: parse settings, then prepare the instance for hibernation.

    Exits with status 0 when hibernation is not enabled for the instance and
    with status 1 when the system does not meet the requirements. Otherwise a
    swap file is created if needed and the kernel/GRUB resume settings are
    updated.
    """
    # Parse arguments
    parser = argparse.ArgumentParser(
        description="An EC2 background process that creates a setup for instance hibernation "
                    "at instance launch and also registers ACPI sleep event/actions")
    parser.add_argument('-c', '--config', help='Configuration file to use', type=str)
    parser.add_argument("-syslog", "--log-to-syslog", help='Log to syslog', type=str)
    parser.add_argument("-grub", "--grub-update", help='Update GRUB config with resume offset', type=str)
    parser.add_argument("-touch", "--touch-swap", help='Do swap initialization', type=str)
    parser.add_argument("-btrfs", "--btrfs-enabled", help='Sets no copy on write on swap file', type=str)
    parser.add_argument("-p", "--swap-ram-percentage", help='The target swap size as a percentage of RAM', type=int)
    parser.add_argument("-s", "--swap-target-size-mb", help='The target swap size in megabytes', type=int)
    parser.add_argument('--mkswap', help='The command line utility to set up swap', type=str)
    parser.add_argument('--swapon', help='The command line utility to turn on swap', type=str)
    parser.add_argument('--swapoff', help='The command line utility to turn off swap', type=str)

    args = parser.parse_args()

    config_file = ConfigParser()
    # read() returns the list of files it managed to parse; an empty list
    # means the requested file was missing or unreadable.
    if args.config and not config_file.read(args.config):
        print_to_sys_log("Configuration file %s could not be read, using defaults" % args.config)

    config = Config(config_file, args)
    global LOG_TO_SYSLOG
    LOG_TO_SYSLOG = config.log_to_syslog

    print_to_sys_log("Effective config: %s" % config)
    create_state_dir(config.state_dir)

    # Let's first check if we even need to run
    if not hibernation_enabled(config.state_dir):
        print_to_sys_log("Instance Launch has not enabled Hibernation Configured Flag. hibinit-agent exiting!!")
        sys.exit(0)

    # Sets default termination handling
    signal.signal(signal.SIGTERM, default_sigterm_handler)

    valid_to_init_hibernation, new_swap_file_required = validate_system_requirements(config)

    if not valid_to_init_hibernation:
        sys.exit(1)

    if new_swap_file_required:
        sw = SwapInitializer(config)
        try:
            sw.init_swap()
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise Exception("Failed to initialise swap file, reason: %s" % str(e)) from e

    update_kernel_swap_offset(config)


# Run the agent only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
