//! Output a range of IP addresses.
/*
 * Copyright (c) 2021  Peter Pentchev <roam@ringlet.net>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

use std::iter::{FusedIterator, StepBy};
use std::ops::RangeInclusive;

use crate::defs::{AddrExclude, AddrFormat, Config, Error};

use anyhow::anyhow;

/// Render an address in the usual dotted-quad notation, most significant octet first.
fn format_dot(addr: u32) -> String {
    // Big-endian byte order puts the most significant octet first, which is
    // exactly the left-to-right order of dotted-quad notation.
    let [o1, o2, o3, o4] = addr.to_be_bytes();
    format!("{o1}.{o2}.{o3}.{o4}")
}

/// Render an address as a single unsigned decimal integer.
fn format_dec(addr: u32) -> String {
    addr.to_string()
}

/// Render an address as lowercase hexadecimal, without a `0x` prefix or zero padding.
fn format_hex(addr: u32) -> String {
    format!("{value:x}", value = addr)
}

/// An iterator that returns the next address string to output.
///
/// Every `u32` yielded by the inner iterator is passed through `formatter`,
/// so this iterator yields exactly one string per inner address. Once the
/// inner iterator returns `None`, this one keeps returning `None` without
/// polling it again.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct OutputRange<'fmt, I: Iterator<Item = u32>> {
    /// The function to use to format each address as a string.
    formatter: &'fmt dyn Fn(u32) -> String,

    /// The iterator producing the addresses to format.
    addr_iter: I,

    /// Set once `addr_iter` has been exhausted, so that it is never polled again.
    done: bool,
}

impl<I: Iterator<Item = u32>> Iterator for OutputRange<'_, I> {
    type Item = String;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        if let Some(addr) = self.addr_iter.next() {
            Some((self.formatter)(addr))
        } else {
            self.done = true;
            None
        }
    }

    /// Each address maps to exactly one string, so the inner iterator's
    /// bounds carry over unchanged until it has been exhausted.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.done {
            (0, Some(0))
        } else {
            self.addr_iter.size_hint()
        }
    }
}

impl<I: Iterator<Item = u32>> FusedIterator for OutputRange<'_, I> {}

/// An iterator adapter that skips the addresses matched by any exclusion pattern.
///
/// An address is dropped if, for at least one [`AddrExclude`] entry, masking
/// the address with the entry's `mask()` yields exactly the entry's `value()`.
/// With no exclusion list (`None`), every address is passed through unchanged.
#[derive(Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct AddrExcludedIter<'cfg, I: Iterator<Item = u32>> {
    /// The source of candidate addresses.
    addr_iter: I,

    /// The exclusion patterns to check each address against, if any.
    exclude: &'cfg Option<Vec<AddrExclude>>,

    /// Set once `addr_iter` has been exhausted, so that it is never polled again.
    done: bool,
}

impl<I: Iterator<Item = u32>> Iterator for AddrExcludedIter<'_, I> {
    type Item = u32;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        // Decide once per call whether any filtering is needed at all, instead
        // of re-checking the (invariant) `Option` for every skipped address.
        let found = match self.exclude.as_deref() {
            Some(patterns) => self.addr_iter.find(|&addr| {
                !patterns
                    .iter()
                    .any(|excl| (addr & excl.mask()) == excl.value())
            }),
            None => self.addr_iter.next(),
        };
        if found.is_none() {
            self.done = true;
        }
        found
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.done {
            return (0, Some(0));
        }
        let (lower, upper) = self.addr_iter.size_hint();
        if self.exclude.is_some() {
            // Any number of the remaining addresses may turn out to be excluded.
            (0, upper)
        } else {
            (lower, upper)
        }
    }
}

impl<I: Iterator<Item = u32>> FusedIterator for AddrExcludedIter<'_, I> {}

/// Output a range of IP addresses.
///
/// Builds a lazy iterator over the addresses from `cfg.range.start` to
/// `cfg.range.end` inclusive, advancing by `cfg.step`. Addresses matched by
/// `cfg.exclude` are skipped, and the rest are rendered as strings using the
/// formatter selected by `cfg.format`.
///
/// # Panics
///
/// Panics if `cfg.step` is zero, because [`Iterator::step_by`] does.
pub fn output_range(
    cfg: &Config,
) -> OutputRange<'_, AddrExcludedIter<'_, StepBy<RangeInclusive<u32>>>> {
    let stepped = (cfg.range.start..=cfg.range.end).step_by(cfg.step);
    let filtered = AddrExcludedIter {
        addr_iter: stepped,
        exclude: &cfg.exclude,
        done: false,
    };

    let formatter: &dyn Fn(u32) -> String = match cfg.format {
        AddrFormat::Hex => &format_hex,
        AddrFormat::Dec => &format_dec,
        AddrFormat::Dot => &format_dot,
    };

    OutputRange {
        formatter,
        addr_iter: filtered,
        done: false,
    }
}

/// Convert a range into CIDR form.
///
/// Returns the smallest CIDR block, formatted as `a.b.c.d/prefixlen`, that
/// contains both `cfg.range.start` and `cfg.range.end`. The prefix length is
/// the number of leading bits the two endpoints have in common. The base
/// address is `cfg.range.start` with every bit past that prefix cleared.
///
/// If the range is not aligned to a power-of-two boundary, the block is wider
/// than the range itself. For example, a range from 10.0.0.5 to 10.0.0.9
/// yields `10.0.0.0/28`.
///
/// # Errors
///
/// Returns [`Error::Internal`] if the network mask or the prefix length cannot
/// be computed. On that path the bit offset is at most 30, so neither failure
/// is expected to happen.
pub fn cidrize(cfg: &Config) -> Result<String, Error> {
    let diff = cfg.range.start ^ cfg.range.end;
    let (base, prefixlen) = if diff == 0 {
        // Both endpoints are the same address: a single-host block.
        (cfg.range.start, 32_u32)
    } else {
        // Bit index (0 = least significant) of the highest bit in which the
        // endpoints differ; every bit above it belongs to the common prefix.
        let offset = diff.ilog2();
        if offset == 31 {
            // Even the topmost bit differs, so only 0.0.0.0/0 covers both.
            (0, 0)
        } else {
            let prefixlen = 31_u32.checked_sub(offset).ok_or_else(|| {
                Error::Internal(anyhow!(
                    "Got an offset larger than 31 for {start}..{end}",
                    start = cfg.range.start,
                    end = cfg.range.end,
                ))
            })?;
            // Keep exactly `prefixlen` bits: the differing bit itself must be
            // cleared too (shift by `offset + 1 == 32 - prefixlen`). Keeping it
            // would give a misaligned base whenever `start` has that bit set,
            // i.e. if the range were given with start > end.
            let mask = offset
                .checked_add(1)
                .and_then(|host_bits| u32::MAX.checked_shl(host_bits))
                .ok_or_else(|| Error::Internal(anyhow!("Could not compute a mask for {offset}")))?;
            (cfg.range.start & mask, prefixlen)
        }
    };
    Ok(format!("{base}/{prefixlen}", base = format_dot(base)))
}

#[cfg(test)]
#[allow(clippy::panic_in_result_fn)]
mod tests {
    use anyhow::{Context as _, Result};
    use itertools::Itertools as _;
    use rstest::rstest;

    use crate::defs::{AddrFormat, AddrRange, Config};
    use crate::parse;

    #[rstest]
    #[case(AddrFormat::Dot, 0x11, 0x13, 1, None, "0.0.0.17 0.0.0.18 0.0.0.19")]
    #[case(AddrFormat::Dec, 0x11, 0x13, 1, None, "17 18 19")]
    #[case(AddrFormat::Hex, 0x11, 0x13, 1, None, "11 12 13")]
    #[case(AddrFormat::Dot, 0x11, 0x13, 2, None, "0.0.0.17 0.0.0.19")]
    #[case(AddrFormat::Dec, 0x11, 0x13, 2, None, "17 19")]
    #[case(AddrFormat::Hex, 0x11, 0x13, 2, None, "11 13")]
    #[case(
        AddrFormat::Dot,
        0x11,
        0x14,
        1,
        None,
        "0.0.0.17 0.0.0.18 0.0.0.19 0.0.0.20"
    )]
    #[case(AddrFormat::Dec, 0x11, 0x14, 1, None, "17 18 19 20")]
    #[case(AddrFormat::Hex, 0x11, 0x14, 1, None, "11 12 13 14")]
    #[case(AddrFormat::Dot, 0x11, 0x14, 2, None, "0.0.0.17 0.0.0.19")]
    #[case(AddrFormat::Dec, 0x11, 0x14, 2, None, "17 19")]
    #[case(AddrFormat::Hex, 0x11, 0x14, 2, None, "11 13")]
    // A step larger than the whole range yields only the start address.
    #[case(AddrFormat::Hex, 0x11, 0x13, 5, None, "11")]
    // All four octets are significant in dotted form; hex is lowercase, unpadded.
    #[case(
        AddrFormat::Dot,
        0xC0_A8_00_01,
        0xC0_A8_00_03,
        2,
        None,
        "192.168.0.1 192.168.0.3"
    )]
    #[case(AddrFormat::Hex, 0xC0_A8_00_01, 0xC0_A8_00_01, 1, None, "c0a80001")]
    #[case(AddrFormat::Dot, 0x11, 0x13, 1, Some("...17"), "0.0.0.18 0.0.0.19")]
    #[case(AddrFormat::Dot, 0x11, 0x13, 1, Some("...17,19"), "0.0.0.18")]
    #[case(AddrFormat::Dot, 0x11, 0x13, 1, Some("..0"), "")]
    fn test_output_range(
        #[case] format: AddrFormat,
        #[case] start: u32,
        #[case] end: u32,
        #[case] step: usize,
        #[case] exclude_pattern: Option<&str>,
        #[case] expected: &str,
    ) -> Result<()> {
        let exclude = if let Some(pattern) = exclude_pattern {
            Some(parse::parse_exclude(pattern)?)
        } else {
            None
        };
        let cfg = Config {
            delim: '?',
            exclude,
            format,
            range: AddrRange { start, end },
            step,
        };
        let res = super::output_range(&cfg).join(" ");
        assert_eq!(res, expected);
        Ok(())
    }

    #[rstest]
    #[case(0x10, 0x1F, "0.0.0.16/28")]
    #[case(0x7F_00_00_00, 0x7F_00_01_FF, "127.0.0.0/23")]
    // A single address is a host route.
    #[case(0x0A_00_00_05, 0x0A_00_00_05, "10.0.0.5/32")]
    // Endpoints differing in the topmost bit can only be covered by /0.
    #[case(0, u32::MAX, "0.0.0.0/0")]
    #[case(0x7F_FF_FF_FF, 0x80_00_00_00, "0.0.0.0/0")]
    // An unaligned range is widened to the enclosing aligned block.
    #[case(0x0A_00_00_05, 0x0A_00_00_09, "10.0.0.0/28")]
    fn test_cidrize(#[case] start: u32, #[case] end: u32, #[case] expected: &str) -> Result<()> {
        let cfg = Config {
            delim: '@',
            exclude: None,
            format: AddrFormat::Hex,
            range: AddrRange { start, end },
            step: 16,
        };
        assert_eq!(
            super::cidrize(&cfg).with_context(|| format!("Could not cidrize {cfg:?}"))?,
            expected
        );
        Ok(())
    }
}
