## This file is part of mlpy.
## Discrete Wavelet Transform (DWT).

## This is an implementation of Discrete Wavelet Transform described in:
## Prabakaran Subramani, Rajendra Sahu and Shekhar Verma.
## 'Feature selection using Haar wavelet power spectrum'.
## In BMC Bioinformatics 2006, 7:432.
    
## This code is written by Giuseppe Jurman, <jurman@fbk.eu> and Davide Albanese, <albanese@fbk.eu>.
## (C) 2008 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.

## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.

__all__ = ['Dwt', 'haar', 'haar_spectrum']

import math
from numpy import *

SQRT_2 = sqrt(2.0)
LOG_2  = log(2.0)


def haar(d):
    """
    Haar wavelet decomposition.

    The signal is zero-padded on the right up to the next power of two,
    2**n, and then decomposed level by level. At each level every
    adjacent pair (a, b) of the current approximation is replaced by the
    detail coefficient (a - b) / sqrt(2) and the approximation
    coefficient (a + b) / sqrt(2).

    Input

      * *d* - [1d array_like] signal of length m >= 1

    Output

      * *dwt* - [1d ndarray float] of length 2**n, n = ceil(log2(m)).
        dwt[0] is the final approximation coefficient. For each level
        j = 1..n, dwt[2**(j-1):2**j] holds that level's 2**(j-1) detail
        coefficients (j = 1 is the coarsest level). Within a level the
        coefficients appear in reversed order.

    Raises ValueError if *d* is empty.
    """

    d = asarray(d)
    m = d.shape[0]
    if m == 0:
        raise ValueError("haar: the input signal must not be empty")

    # Exact integer ceil(log2(m)). A floating-point log ratio can be
    # off by an ulp and pick the wrong power of two.
    n = (m - 1).bit_length()
    two_n = 1 << n
    root2 = sqrt(2.0)

    dwt = zeros(two_n, dtype=float)
    dwt[:m] = d

    # Each pass transforms the trailing 2**j coefficients. Details go to
    # the first half of that tail and approximations to the second half,
    # so the next pass works on a tail half as long.
    for j in range(n, 0, -1):
        offset = two_n - (1 << j)
        half = 1 << (j - 1)
        # Copy so that reading even/odd is not affected by the writes below.
        seg = dwt[offset:].copy()
        even, odd = seg[0::2], seg[1::2]
        dwt[offset:offset + half] = (even - odd) / root2
        dwt[offset + half:] = (even + odd) / root2

    return dwt[::-1]


def haar_spectrum(dwt):
    """
    Compute spectrum from wavelet decomposition.

    Input

      * *dwt* - [1d array_like float] coefficients laid out as returned
        by haar(): dwt[0] is the approximation coefficient and
        dwt[2**(j-1):2**j] holds the detail coefficients of level j

    Output

      * *spec* - [1d ndarray float] of length n + 1, with
        n = floor(log2(len(dwt))). spec[0] is dwt[0]**2 carrying the
        sign of dwt[0]. spec[j] (j >= 1) is the sum of the squared
        coefficients of level j.

    Raises ValueError if *dwt* is empty.
    """

    dwt = asarray(dwt)
    size = dwt.shape[0]
    if size == 0:
        raise ValueError("haar_spectrum: the decomposition must not be empty")

    # Exact integer floor(log2(size)). Truncating a floating-point log
    # ratio can lose a level when the ratio lands just below an integer.
    n = size.bit_length() - 1

    spec = zeros(n + 1, dtype=float)

    # Level 0: squared approximation coefficient, keeping its sign.
    spec[0] = dwt[0] * dwt[0]
    if dwt[0] < 0.0:
        spec[0] = -spec[0]

    # Levels 1..n: energy of each block of detail coefficients.
    for j in range(1, n + 1):
        level = dwt[2 ** (j - 1):2 ** j]
        spec[j] = (level ** 2).sum()

    return spec


def rpv(s1, s2):
    """
    Relative Percentage Variation (RPV).

    Returns (mean(s1) - mean(s2)) / mean(s1) * 100, i.e. the signed
    difference of the two means as a percentage of the mean of *s1*.
    There is no guard against mean(s1) == 0.
    """

    m1, m2 = mean(s1), mean(s2)
    variation = (m1 - m2) / m1
    return variation * 100


def arpv(s1, s2):
    """
    Absolute Relative Percentage Variation (ARPV).

    Geometric mean of the absolute RPV taken in both directions,
    sqrt(|rpv(s1, s2)| * |rpv(s2, s1)|); non-negative for finite inputs.
    """

    forward = abs(rpv(s1, s2))
    backward = abs(rpv(s2, s1))
    return sqrt(forward * backward)
    

def crpv(s1, s2, f, y):
    """
    Correlation Relative Percentage Variation (CRPV).

    ARPV of the two spectra *s1* and *s2*, scaled by the absolute value
    of numpy.correlate(f, y) in its default 'valid' mode. For
    equal-length *f* and *y* that correlation is a 1-element array, so
    the result is a 1-element array too.
    """

    ratio = arpv(s1, s2)
    corr = correlate(f, y)
    return ratio * abs(corr)
    

def compute_dwt(x, y, specdiff = 'rpv'):
    """
    Compute DWT feature weights.

    For every feature (column of *x*), the values of the samples
    labelled 1 and those labelled -1 are Haar-decomposed separately.
    Their spectra are then compared with the chosen spectral
    difference.

    Input

      * *x* - [2d array_like float] data (samples x features)
      * *y* - [1d array_like int] class labels, one per row of *x*; the
        spectra use the samples labelled 1 and -1
      * *specdiff* - [string] spectral difference: 'rpv', 'arpv' or 'crpv'

    Output

      * *w* - [1d ndarray float] one weight per feature

    Raises ValueError for an unknown *specdiff* or when *x* is not 2d
    with one row per element of *y*.
    """

    if specdiff not in ('rpv', 'arpv', 'crpv'):
        raise ValueError("specdiff must be one of 'rpv', 'arpv', 'crpv'")

    x = asarray(x)
    y = asarray(y)
    if x.ndim != 2 or x.shape[0] != y.shape[0]:
        raise ValueError("x must be a 2d array with one row per element of y")

    pos = (y == 1)
    neg = (y == -1)

    w = zeros(x.shape[1], dtype=float)

    for f in range(x.shape[1]):
        column = x[:, f]

        s1 = haar_spectrum(haar(column[pos]))
        s2 = haar_spectrum(haar(column[neg]))

        if specdiff == 'rpv':
            w[f] = rpv(s1, s2)
        elif specdiff == 'arpv':
            w[f] = arpv(s1, s2)
        else:
            # crpv() returns the 1-element array produced by
            # numpy.correlate; .item() turns it into a scalar instead of
            # relying on deprecated array-to-scalar assignment.
            w[f] = asarray(crpv(s1, s2, column, y)).item()

    return w


class Dwt:
    """Discrete Wavelet Transform (DWT).

    Feature weighting based on Haar wavelet power spectra, after
    Subramani, Sahu and Verma, 'Feature selection using Haar wavelet
    power spectrum', BMC Bioinformatics 2006, 7:432.

    Example (the exact spacing of the printed array depends on the
    NumPy version):
    
    >>> import numpy as np
    >>> import mlpy
    >>> xtr = np.array([[1.0, 2.0, 3.1, 1.0],  # first sample
    ...                 [1.0, 2.0, 3.0, 2.0],  # second sample
    ...                 [1.0, 2.0, 3.1, 1.0]]) # third sample
    >>> ytr = np.array([1, -1, 1])             # classes
    >>> mydwt = mlpy.Dwt()                   # initialize dwt class
    >>> mydwt.weights(xtr, ytr)              # compute weights on training data
    array([ -2.22044605e-14,  -2.22044605e-14,   6.34755463e+00,  -3.00000000e+02])
    """

    # Accepted spectral difference methods.
    SPECDIFFS = ['rpv', 'arpv', 'crpv']

    def __init__(self, specdiff = 'rpv'):
        """Initialize the Dwt class.

        Input
        
          * *specdiff* - [string] spectral difference method ('rpv', 'arpv', 'crpv')

        Raises ValueError if *specdiff* is not in SPECDIFFS.
        """

        if specdiff not in self.SPECDIFFS:
            raise ValueError("specdiff (spectral difference) must be in %s" % self.SPECDIFFS)

        self.__specdiff = specdiff
        # Class labels seen by the last call to weights().
        self.__classes  = None

    def weights(self, x, y):
        """Return feature weights, one per column of *x*.

        The weights are not absolute values: with the default 'rpv'
        method they are signed.
        
        :Parameters:
          x : 2d array_like float (samples x feats)
            training data
          y : 1d array_like integer (-1 or 1)
            classes, one per row of x
        
        :Returns:
          fw :  1d ndarray float
            feature weights

        :Raises:
          ValueError
            if y does not contain exactly the two classes -1 and 1, or
            if x is not 2d with one row per element of y
        """

        x = asarray(x)
        y = asarray(y)

        self.__classes = unique(y)

        if self.__classes.shape[0] != 2:
            raise ValueError("DWT algorithm works only for two-classes problems")

        if self.__classes[0] != -1 or self.__classes[1] != 1:
            raise ValueError("DWT algorithm works only for 1 and -1 classes")

        if x.ndim != 2 or x.shape[0] != y.shape[0]:
            raise ValueError("x must be a 2d array with one row per element of y")

        return compute_dwt(x, y, self.__specdiff)
    

    
