#!/usr/bin/env python3

# Copyright: 2026 Hector CAO <hector.cao@canonical.com>
# SPDX-License-Identifier: GPL-3.0-or-later

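"""APT version 2 hook: parse and detect ubuntu-virt packages.

Parses APT DPkg::Pre-Install-Pkgs hook input (protocol version 2) and detects
whether ubuntu-virt or ubuntu-virt-hwe is being installed. If so, it warns
about virt-stack packages that are being removed without their base/-hwe
counterpart being installed. The hook only prints warnings: unexpected errors
are reported on stderr without failing the APT transaction.
"""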

from __future__ import annotations

import os
import subprocess
import sys

def _read_hook_input() -> str:
    """Return the complete hook payload written by APT.

    If ``APT_HOOK_INFO_FD`` names a file descriptor, the payload is read from
    it as UTF-8 (undecodable bytes are replaced), leaving the descriptor open.
    If the variable is unset, is not an integer, or reading the descriptor
    fails with ``OSError``, standard input is read instead.
    """
    payload = None
    raw_fd = os.environ.get("APT_HOOK_INFO_FD")
    if raw_fd is not None:
        try:
            descriptor = int(raw_fd)
            # closefd=False: the descriptor belongs to APT, not to us.
            with os.fdopen(descriptor, "r", encoding="utf-8", errors="replace", closefd=False) as stream:
                payload = stream.read()
        except (ValueError, OSError):
            payload = None
    return sys.stdin.read() if payload is None else payload


def _parse_v2_transactions(content: str) -> tuple[list[str], list[str]]:
    """Parse APT hook protocol v2 input into installed and removed packages.

    A complete v2 message looks like::

        VERSION 2
        Key=Value          (APT configuration dump, one item per line)
        ...
        <blank line>
        PackageName CurrentVersion Operator NewVersion Action
        ...

    On package lines:
      - CurrentVersion / NewVersion are ``-`` when there is no such version.
      - Operator compares NewVersion with CurrentVersion: ``<`` upgrade or new
        install, ``>`` downgrade *or* removal (removals are written ``> -``),
        ``=`` reinstall.  The operator alone therefore does not identify a
        removal.
      - Action is the path of the .deb being unpacked, ``**CONFIGURE**`` or
        ``**REMOVE**`` (APT also uses it for purges).

    If *content* starts with a ``VERSION`` line, everything up to the first
    blank line is skipped: configuration values may themselves contain spaces
    and ``<``/``>`` (e.g. shell redirections in hook commands).  Input without
    such a header is treated as package lines only.

    Returns (installs, removals) in input order.  A package that is both
    unpacked and configured appears twice in installs.
    """
    lines = content.splitlines()

    start = 0
    if lines and lines[0].startswith("VERSION "):
        # Skip the header and configuration dump, terminated by the first
        # blank line; without a terminator no package lines follow.
        start = next(
            (idx + 1 for idx, line in enumerate(lines) if not line.strip()),
            len(lines),
        )

    installs: list[str] = []
    removals: list[str] = []
    for line in lines[start:]:
        tokens = line.split()
        # v2 package lines have exactly five fields (v3 adds architecture
        # and multi-arch fields and is not handled here).
        if len(tokens) != 5:
            continue
        package, _current, operator, _new, action = tokens
        if operator not in ("<", ">", "="):
            continue
        # **PURGE** is accepted for tolerance; APT writes **REMOVE** for purges.
        if action in ("**REMOVE**", "**PURGE**"):
            removals.append(package)
        else:
            # A .deb path, **CONFIGURE**, or **ERROR** (an unpack without a
            # usable .deb path): the package ends up installed, including
            # downgrades ('>') and reinstalls ('=').
            installs.append(package)

    return (installs, removals)

# Source packages that make up the virtualization stack.  Binary packages
# built from these (or from their "-hwe" variants) are checked for orphaned
# removals.
_VIRT_SOURCES = {"edk2", "libvirt", "qemu", "seabios"}

def _source_packages(pkg_names: list[str]) -> dict[str, str]:
    """Map each binary package name in *pkg_names* to its source package.

    All names are resolved with a single ``dpkg-query -W`` call.  An empty
    ``Source`` field (source and binary names coincide) maps to the binary
    name.  A field with extra space-separated text (e.g. a version suffix)
    is cut at the first space.  Names that dpkg-query does not report on
    stdout (e.g. packages not installed yet) map to themselves.
    """
    if not pkg_names:
        return {}

    query = subprocess.run(
        ["dpkg-query", "-W", "-f=${Package}\t${Source}\n", *pkg_names],
        capture_output=True,
        text=True,
        check=False,
    )

    sources: dict[str, str] = {}
    for row in query.stdout.splitlines():
        binary, tab, source_field = row.partition("\t")
        if not tab:
            continue
        sources[binary] = source_field.partition(" ")[0] if source_field.strip() else binary

    # Anything dpkg-query did not report falls back to its own name.
    for name in pkg_names:
        if name not in sources:
            sources[name] = name
    return sources


def _counterpart(name: str) -> str:
    """Map a base package name to its ``-hwe`` twin, and a ``-hwe`` name back to its base."""
    suffix = "-hwe"
    if not name.endswith(suffix):
        return f"{name}{suffix}"
    return name[: -len(suffix)]


def orphaned_removals(installs: list[str], removals: list[str]) -> list[str]:
    """Return virt-stack packages being removed without their counterpart.

    A removed package is reported when its source package belongs to
    ``_VIRT_SOURCES``, either directly or once a trailing ``-hwe`` is
    stripped.  The source is looked up with dpkg-query, so that e.g. ovmf
    (source: edk2) is classified correctly without relying on name prefixes.
    A package is also only reported when its base/``-hwe`` counterpart is
    not in *installs*.

    Args:
        installs: names of packages being installed/configured.
        removals: names of packages being removed; duplicates are reported once.

    Returns:
        The orphaned package names, in first-seen order of *removals*.
    """
    # Nothing is being removed: no need to spawn dpkg-query at all.
    if not removals:
        return []

    # Only removed packages need a source lookup; installed packages are
    # matched by name only (and are usually unknown to dpkg before unpacking).
    unique_removals = list(dict.fromkeys(removals))
    sources = _source_packages(unique_removals)
    install_set = set(installs)

    orphans: list[str] = []
    for pkg in unique_removals:
        src = sources.get(pkg, pkg)
        # "<source>-hwe" belongs to the same stack as "<source>".
        canonical_src = src[:-4] if src.endswith("-hwe") else src
        in_virt_stack = src in _VIRT_SOURCES or canonical_src in _VIRT_SOURCES
        if in_virt_stack and _counterpart(pkg) not in install_set:
            orphans.append(pkg)
    return orphans


def main() -> int:
    """Warn when an ubuntu-virt[-hwe] switch would orphan virt-stack packages.

    Reads the APT hook payload.  If ubuntu-virt or ubuntu-virt-hwe is among
    the packages being installed, it prints the virt-stack packages that are
    being removed without their counterpart being installed.  The check is
    advisory only: the function always returns 0.
    """
    content = _read_hook_input()
    installs, removals = _parse_v2_transactions(content)

    # Skip if no virt switch/install is detected
    if "ubuntu-virt" not in installs and "ubuntu-virt-hwe" not in installs:
        return 0

    orphaned = orphaned_removals(installs, removals)
    if not orphaned:
        return 0

    # Emit ANSI colour codes only on a terminal so that logs and pipes
    # (e.g. non-interactive apt runs) do not get raw escape sequences.
    use_color = sys.stdout is not None and sys.stdout.isatty()

    def emit(message: str) -> None:
        print(f"\033[31m{message}\033[0m" if use_color else message)

    emit("Virt: switch of the ubuntu-virt[-hwe] stack is detected and there are orphaned removals (no counterpart being installed):")
    for pkg in orphaned:
        emit(f"  - {pkg}")
    emit("Virt: Please install them back and use the ubuntu-virt-helper script to switch between ubuntu-virt[-hwe] packages.")

    return 0

if __name__ == "__main__":
    # The hook is purely advisory: report any unexpected failure on stderr
    # and still exit 0 so the APT transaction is never aborted by it.
    try:
        main()
    except Exception as err:
        print(f"apt-hook: warning: unexpected error: {err}", file=sys.stderr)
        sys.exit(0)
